{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np # get rid of this eventually\n",
    "import argparse\n",
    "from jax import jit\n",
    "from jax.experimental.ode import odeint\n",
    "from functools import partial # reduces arguments to function by making some subset implicit\n",
    "\n",
    "from jax.experimental import stax\n",
    "from jax.experimental import optimizers\n",
    "\n",
    "import os, sys, time\n",
    "sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append('../experiment_dblpend/')\n",
    "\n",
    "from lnn import lagrangian_eom_rk4, lagrangian_eom, unconstrained_eom, raw_lagrangian_eom\n",
    "from data import get_dataset\n",
    "from models import mlp as make_mlp\n",
    "from utils import wrap_coords"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append('../hyperopt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from HyperparameterSearch import learned_dynamics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from HyperparameterSearch import extended_mlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ObjectView:\n",
    "    \"\"\"Expose a dict's keys as attributes (``view.key`` instead of ``d['key']``).\n",
    "\n",
    "    The dict is installed directly as the instance ``__dict__`` (shared, not\n",
    "    copied), so changes made through either the dict or the attributes are\n",
    "    visible through both. Used below to hold the hyperparameters ``args``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d):\n",
    "        # Share the mapping itself as this instance's attribute storage.\n",
    "        self.__dict__ = d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data import get_trajectory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data import get_trajectory_analytic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from physics import analytical_fn\n",
    "\n",
    "# Batched, jit-compiled helpers built on the analytic model:\n",
    "# - vfnc: analytical_fn mapped over the leading axis of its input.\n",
    "# - vget: get_trajectory_analytic (mxsteps=100) mapped over the leading axis of\n",
    "#   its first argument with the second argument shared; compiled for the CPU.\n",
    "vfnc = jax.jit(jax.vmap(analytical_fn, in_axes=0, out_axes=0))\n",
    "vget = jax.jit(\n",
    "    jax.vmap(partial(get_trajectory_analytic, mxsteps=100), in_axes=(0, None), out_axes=0),\n",
    "    backend='cpu',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 0.29830917716026306 {'act': [4],\n",
    "# 'batch_size': [27.0], 'dt': [0.09609870774790222],\n",
    "# 'hidden_dim': [596.0], 'l2reg': [0.24927677946969878],\n",
    "# 'layers': [4.0], 'lr': [0.005516656601005163],\n",
    "# 'lr2': [1.897157209816416e-05], 'n_updates': [4.0]}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
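    "### Load the best model (optional)\n",
    "\n",
    "Loading a saved model is currently commented out below (the `pkl.load` cell and `args = loaded['args']`), so as written the notebook trains a new model instead. To generate more models, see the code below."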
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loaded = pkl.load(open('./params_for_loss_0.29429444670677185_nupdates=1.pkl', 'rb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for this run, accessed as attributes (args.hidden_dim, ...).\n",
    "# Alternatively, take the settings of a saved run: args = loaded['args']\n",
    "# NOTE: the training cell below uses its own `lr` and epoch count rather than\n",
    "# args.lr / args.num_epochs.\n",
    "args = ObjectView(dict(\n",
    "    dataset_size=200,\n",
    "    fps=10,\n",
    "    samples=100,\n",
    "    num_epochs=80000,\n",
    "    seed=0,\n",
    "    loss='l1',\n",
    "    act='relu_relu',\n",
    "    hidden_dim=600,\n",
    "    output_dim=2,\n",
    "    layers=3,\n",
    "    n_updates=1,\n",
    "    lr=0.001,\n",
    "    lr2=2e-05,\n",
    "    dt=0.1,\n",
    "    model='gln',\n",
    "    batch_size=512,\n",
    "    l2reg=5.7e-07,\n",
    "))\n",
    "\n",
    "# Base PRNG key derived from the configured seed.\n",
    "rng = jax.random.PRNGKey(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.experimental.ode import odeint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from HyperparameterSearch import new_get_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# vfnc and vget are defined once, in the cell that imports analytical_fn;\n",
    "# they are intentionally not redefined here.\n",
    "minibatch_per = 2000  # minibatches per freshly sampled dataset (one epoch)\n",
    "batch = 512  # samples per minibatch\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def get_derivative_dataset(rng):\n",
    "    \"\"\"Sample random states and their analytic time derivatives.\n",
    "\n",
    "    Args:\n",
    "        rng: JAX PRNG key. The velocity draw uses ``rng + 1`` rather than a\n",
    "            split key; this is kept so the generated data stays unchanged.\n",
    "\n",
    "    Returns:\n",
    "        ``(y0, vfnc(y0))`` with ``y0`` of shape ``(batch * minibatch_per, 4)``:\n",
    "        columns 0-1 (``q``) are uniform in [0, 2*pi) and columns 2-3\n",
    "        (``q_t``) are uniform in [-10, 10).\n",
    "\n",
    "    ``batch`` and ``minibatch_per`` are read as globals when the function is\n",
    "    first traced, so reassigning them later does not resize an existing trace.\n",
    "    \"\"\"\n",
    "    n_samples = batch * minibatch_per\n",
    "    q = jax.random.uniform(rng, (n_samples, 2)) * 2.0 * jnp.pi\n",
    "    q_t = (jax.random.uniform(rng + 1, (n_samples, 2)) - 0.5) * 10 * 2\n",
    "    y0 = jnp.concatenate([q, q_t], axis=1)\n",
    "    return y0, vfnc(y0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lowest epoch-average training loss seen so far, and a host-side copy of the\n",
    "# parameters that produced it; both are updated by the training loop below.\n",
    "best_params, best_loss = None, np.inf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def baseline_eom(baseline, state, t=None):\n",
    "    \"\"\"Time derivative of ``state = [q, q_t]`` under a learned acceleration model.\n",
    "\n",
    "    Args:\n",
    "        baseline: callable ``(q, q_t) -> q_tt`` returning accelerations.\n",
    "        state: array split in half along its first axis into ``q`` and\n",
    "            ``q_t`` (a length-4 row under the vmap in ``loss``).\n",
    "        t: ignored; presumably present so the signature matches ODE-integrator\n",
    "            callbacks.\n",
    "\n",
    "    Returns:\n",
    "        ``concatenate([q_t, q_tt])``, where ``q`` is wrapped into [0, 2*pi)\n",
    "        before being passed to ``baseline``.\n",
    "    \"\"\"\n",
    "    position, velocity = jnp.split(state, 2)\n",
    "    acceleration = baseline(jnp.mod(position, 2 * jnp.pi), velocity)\n",
    "    return jnp.concatenate((velocity, acceleration))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4e-05\n"
     ]
    }
   ],
   "source": [
    "init_random_params, nn_forward_fn = extended_mlp(args)\n",
    "import HyperparameterSearch\n",
    "# learned_dynamics (from HyperparameterSearch) is handed this network through a\n",
    "# module attribute -- presumably it looks up nn_forward_fn as a module global.\n",
    "HyperparameterSearch.nn_forward_fn = nn_forward_fn\n",
    "_, init_params = init_random_params(rng+1, (-1, 4))\n",
    "rng += 1\n",
    "model = (nn_forward_fn, init_params)\n",
    "\n",
    "from jax.tree_util import tree_flatten\n",
    "from HyperparameterSearch import make_loss, train\n",
    "from copy import deepcopy as copy\n",
    "from lnn import custom_init\n",
    "import math\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def loss(params, batch, l2reg):\n",
    "    \"\"\"Summed L1 error of predicted state derivatives, plus weight decay.\n",
    "\n",
    "    Args:\n",
    "        params: network parameters (a pytree).\n",
    "        batch: ``(state, targets)``; rows of ``state`` are ``[q, q_t]`` and the\n",
    "            matching rows of ``targets`` are the analytic derivatives produced\n",
    "            by ``get_derivative_dataset``.\n",
    "        l2reg: weight-decay coefficient; the squared L2 norm of all parameters\n",
    "            is scaled by ``l2reg / args.batch_size``.\n",
    "\n",
    "    Returns:\n",
    "        Scalar: sum over the batch of ``|pred - target|`` plus the decay term.\n",
    "    \"\"\"\n",
    "    state, targets = batch\n",
    "    leaves, _ = tree_flatten(params)\n",
    "    l2_norm = sum(jnp.vdot(param, param) for param in leaves)\n",
    "    preds = jax.vmap(\n",
    "        partial(\n",
    "            baseline_eom,\n",
    "            learned_dynamics(params)))(state)\n",
    "    return jnp.sum(jnp.abs(preds - targets)) + l2reg*l2_norm/args.batch_size\n",
    "\n",
    "\n",
    "lr = 4e-5  # peak learning rate of the schedule (was 1e-3)\n",
    "final_div_factor = 1e4  # the schedule ends at lr / final_div_factor\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def OneCycleLR(pct):\n",
    "    \"\"\"Cosine one-cycle learning-rate schedule.\n",
    "\n",
    "    ``pct`` is training progress in [0, 1]. It is remapped to [0.2, 1] so the\n",
    "    run starts part-way up the warm-up, and the rate is then\n",
    "    ``low + (high - low) * (1 - cos(2*pi*p)) / 2`` with ``high = lr`` and\n",
    "    ``low = lr / final_div_factor``: it peaks at ``lr`` (p = 0.5) and ends at\n",
    "    ``low`` (p = 1). ``lr`` and ``final_div_factor`` are read when the function\n",
    "    is first traced.\n",
    "    \"\"\"\n",
    "    start = 0.2\n",
    "    pct = pct * (1-start) + start\n",
    "    high, low = lr, lr/final_div_factor\n",
    "    scale = 1.0 - (jnp.cos(2 * jnp.pi * pct) + 1)/2\n",
    "    return low + (high - low)*scale\n",
    "\n",
    "\n",
    "total_epochs = 300\n",
    "minibatch_per = 2000\n",
    "# Adam's bias correction depends on the optimizer step index, so the optimizer\n",
    "# is driven by an integer step and the schedule converts it back to a fraction.\n",
    "total_steps = total_epochs * minibatch_per\n",
    "opt_init, opt_update, get_params = optimizers.adam(\n",
    "    lambda step: OneCycleLR(step / total_steps)\n",
    ")\n",
    "opt_state = opt_init(init_params)\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def update_derivative(i, opt_state, batch, l2reg):\n",
    "    \"\"\"Take one Adam step on ``batch``.\n",
    "\n",
    "    Args:\n",
    "        i: training progress in [0, 1) (``fraction`` in the training loop).\n",
    "            It is converted to the step index ``round(i * total_steps)``;\n",
    "            feeding the fraction straight to Adam would keep its bias\n",
    "            correction pinned near the first step for the whole run.\n",
    "        opt_state: current optimizer state.\n",
    "        batch: ``(state, targets)`` minibatch, as accepted by ``loss``.\n",
    "        l2reg: weight-decay coefficient forwarded to ``loss``.\n",
    "\n",
    "    Returns:\n",
    "        ``(new_opt_state, params)``, where ``params`` are the parameters\n",
    "        *before* this update.\n",
    "    \"\"\"\n",
    "    params = get_params(opt_state)\n",
    "    # Average over samples. len(batch) is always 2 here (batch is a\n",
    "    # (state, targets) tuple), so it must not be used as the divisor.\n",
    "    n_samples = batch[0].shape[0]\n",
    "    grads = jax.grad(lambda p: loss(p, batch, l2reg) / n_samples)(params)\n",
    "    step = jnp.round(i * total_steps)\n",
    "    return opt_update(step, grads, opt_state), params\n",
    "\n",
    "\n",
    "# NOTE(review): with the corrected Adam step index the effective update is ~3-4x\n",
    "# larger than in the earlier recorded run (the bias correction no longer stays at\n",
    "# its first-step value); lr may need re-tuning.\n",
    "best_small_loss = np.inf\n",
    "iteration = 0\n",
    "train_losses, test_losses = [], []\n",
    "bad_iterations = 0\n",
    "print(lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Idea (TODO):** add an identity term before taking the inverse."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh key for the training-data stream, derived from the configured seed\n",
    "# (previously hard-coded to 0, which silently ignored args.seed).\n",
    "rng = jax.random.PRNGKey(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Starting epoch for the training loop below. The loop leaves `epoch` at the\n",
    "# epoch where it stopped, so re-running that cell resumes rather than restarts.\n",
    "epoch = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 4)\n"
     ]
    }
   ],
   "source": [
    "# First 10 rows of one generated dataset (states and targets) for quick sanity\n",
    "# checks; the dataset is generated once and both halves are sliced from it.\n",
    "batch_data = tuple(arr[:10] for arr in get_derivative_dataset(rng))\n",
    "print(batch_data[0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(329.27036, dtype=float32)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: loss of the initial parameters on the 10-sample batch, no weight decay.\n",
    "loss(get_params(opt_state), batch_data, l2reg=0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test of one optimizer step on the 10-sample batch. Only the returned\n",
    "# parameters are kept; the updated optimizer state is discarded so this debug\n",
    "# step does not leak into the Adam moments used by the real training run.\n",
    "params = update_derivative(0.0, opt_state, batch_data, 0.0)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# best_loss = np.inf\n",
    "# best_params = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0b5590ea11304e22a0babf5b5cdc133c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0, max=300.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch=0 lr=1.414164034940768e-05 loss=0.08233094215393066\n",
      "epoch=1 lr=1.4462808394455351e-05 loss=0.08085848391056061\n",
      "epoch=2 lr=1.478552985645365e-05 loss=0.08081990480422974\n",
      "epoch=3 lr=1.5109715604921803e-05 loss=0.08108066767454147\n",
      "epoch=4 lr=1.5435276509379037e-05 loss=0.08120445907115936\n",
      "epoch=5 lr=1.5762116163386963e-05 loss=0.08121983706951141\n",
      "epoch=6 lr=1.609015089343302e-05 loss=0.08122540265321732\n",
      "epoch=7 lr=1.641928065510001e-05 loss=0.08126170188188553\n",
      "epoch=8 lr=1.6749418136896566e-05 loss=0.08138350397348404\n",
      "epoch=9 lr=1.70804687513737e-05 loss=0.0815456435084343\n",
      "epoch=10 lr=1.741233791108243e-05 loss=0.08131673187017441\n",
      "epoch=11 lr=1.774493466655258e-05 loss=0.08151450008153915\n",
      "epoch=12 lr=1.807816261134576e-05 loss=0.08159146457910538\n",
      "epoch=13 lr=1.84119344339706e-05 loss=0.08152744919061661\n",
      "epoch=14 lr=1.8746148271020502e-05 loss=0.08140755444765091\n",
      "epoch=15 lr=1.90807186299935e-05 loss=0.08162107318639755\n",
      "epoch=16 lr=1.94155472854618e-05 loss=0.08152826130390167\n",
      "epoch=17 lr=1.9750539649976417e-05 loss=0.08166711032390594\n",
      "epoch=18 lr=2.008560113608837e-05 loss=0.08173049986362457\n",
      "epoch=19 lr=2.0420640794327483e-05 loss=0.08169645071029663\n",
      "epoch=20 lr=2.0755562218255363e-05 loss=0.0818859338760376\n",
      "epoch=21 lr=2.1090272639412433e-05 loss=0.08188760280609131\n",
      "epoch=22 lr=2.1424677470349707e-05 loss=0.08178362250328064\n",
      "epoch=23 lr=2.1758683942607604e-05 loss=0.08194532245397568\n",
      "epoch=24 lr=2.2092195649747737e-05 loss=0.0819844901561737\n",
      "epoch=25 lr=2.2425121642299928e-05 loss=0.08199925720691681\n",
      "epoch=26 lr=2.2757367332815193e-05 loss=0.08223752677440643\n",
      "epoch=27 lr=2.3088836314855143e-05 loss=0.08178704231977463\n",
      "epoch=28 lr=2.3419443095917813e-05 loss=0.08195152878761292\n",
      "epoch=29 lr=2.374908945057541e-05 loss=0.0821177288889885\n",
      "epoch=30 lr=2.4077682610368356e-05 loss=0.08227823674678802\n",
      "epoch=31 lr=2.440513162582647e-05 loss=0.08218372613191605\n",
      "epoch=32 lr=2.473134736646898e-05 loss=0.08209142833948135\n",
      "epoch=33 lr=2.5056231606868096e-05 loss=0.08230350911617279\n",
      "epoch=34 lr=2.5379698854521848e-05 loss=0.08229774236679077\n",
      "epoch=35 lr=2.5701656340970658e-05 loss=0.0824241116642952\n",
      "epoch=36 lr=2.6022013116744347e-05 loss=0.08223366737365723\n",
      "epoch=37 lr=2.6340680051362142e-05 loss=0.0823461040854454\n",
      "epoch=38 lr=2.6657566195353866e-05 loss=0.08237158507108688\n",
      "epoch=39 lr=2.6972587875206955e-05 loss=0.08233390003442764\n",
      "epoch=40 lr=2.7285650503472425e-05 loss=0.08233213424682617\n",
      "epoch=41 lr=2.759666858764831e-05 loss=0.08274270594120026\n",
      "epoch=42 lr=2.7905552997253835e-05 loss=0.08241263777017593\n",
      "epoch=43 lr=2.8212221877765842e-05 loss=0.08232994377613068\n",
      "epoch=44 lr=2.8516586098703556e-05 loss=0.08251188695430756\n",
      "epoch=45 lr=2.8818558348575607e-05 loss=0.0825469121336937\n",
      "epoch=46 lr=2.911805313488003e-05 loss=0.08245673775672913\n",
      "epoch=47 lr=2.941498860309366e-05 loss=0.08256614953279495\n",
      "epoch=48 lr=2.9709288355661556e-05 loss=0.08250924199819565\n",
      "epoch=49 lr=3.000085780513473e-05 loss=0.08281709253787994\n",
      "epoch=50 lr=3.028962055395823e-05 loss=0.0824875682592392\n",
      "epoch=51 lr=3.05754947476089e-05 loss=0.08259851485490799\n",
      "epoch=52 lr=3.0858405807521194e-05 loss=0.0827212706208229\n",
      "epoch=53 lr=3.113826460321434e-05 loss=0.08251755684614182\n",
      "epoch=54 lr=3.1414994737133384e-05 loss=0.08251798152923584\n",
      "epoch=55 lr=3.168852344970219e-05 loss=0.08277525007724762\n",
      "epoch=56 lr=3.195877798134461e-05 loss=0.08270535618066788\n",
      "epoch=57 lr=3.2225663744611666e-05 loss=0.08271445333957672\n",
      "epoch=58 lr=3.248912253184244e-05 loss=0.08273577690124512\n",
      "epoch=59 lr=3.2749077945481986e-05 loss=0.0828409418463707\n",
      "epoch=60 lr=3.300545722595416e-05 loss=0.08270175009965897\n",
      "epoch=61 lr=3.3258183975704014e-05 loss=0.08270849287509918\n",
      "epoch=62 lr=3.3507185435155407e-05 loss=0.08290085196495056\n",
      "epoch=63 lr=3.375239612068981e-05 loss=0.08259326964616776\n",
      "epoch=64 lr=3.3993750548688695e-05 loss=0.08262861520051956\n",
      "epoch=65 lr=3.423117959755473e-05 loss=0.08276861160993576\n",
      "epoch=66 lr=3.446460686973296e-05 loss=0.08297166228294373\n",
      "epoch=67 lr=3.469397415756248e-05 loss=0.08272399008274078\n",
      "epoch=68 lr=3.491922325338237e-05 loss=0.08280134946107864\n",
      "epoch=69 lr=3.514028139761649e-05 loss=0.08285874128341675\n",
      "epoch=70 lr=3.535709038260393e-05 loss=0.08266808837652206\n",
      "epoch=71 lr=3.5569584724726155e-05 loss=0.08307917416095734\n",
      "epoch=72 lr=3.577771713025868e-05 loss=0.08291200548410416\n",
      "epoch=73 lr=3.598141483962536e-05 loss=0.08291705697774887\n",
      "epoch=74 lr=3.6180626921122894e-05 loss=0.08281207084655762\n",
      "epoch=75 lr=3.637529516709037e-05 loss=0.08293124288320541\n",
      "epoch=76 lr=3.65653722838033e-05 loss=0.0830584317445755\n",
      "epoch=77 lr=3.6750796425621957e-05 loss=0.08283168822526932\n",
      "epoch=78 lr=3.693152029882185e-05 loss=0.08262655884027481\n",
      "epoch=79 lr=3.710748933372088e-05 loss=0.08275551348924637\n",
      "epoch=80 lr=3.727865987457335e-05 loss=0.08278914541006088\n",
      "epoch=81 lr=3.744497735169716e-05 loss=0.08292493224143982\n",
      "epoch=82 lr=3.7606401747325435e-05 loss=0.08285250514745712\n",
      "epoch=83 lr=3.7762878491776064e-05 loss=0.08280224353075027\n",
      "epoch=84 lr=3.791437120526098e-05 loss=0.08306962251663208\n",
      "epoch=85 lr=3.806083623203449e-05 loss=0.08288473635911942\n",
      "epoch=86 lr=3.820223355432972e-05 loss=0.08278044313192368\n",
      "epoch=87 lr=3.833851951640099e-05 loss=0.08275149017572403\n",
      "epoch=88 lr=3.846965773846023e-05 loss=0.08290810137987137\n",
      "epoch=89 lr=3.859561184071936e-05 loss=0.08262310177087784\n",
      "epoch=90 lr=3.871634544339031e-05 loss=0.08283252269029617\n",
      "epoch=91 lr=3.883182580466382e-05 loss=0.08289124816656113\n",
      "epoch=92 lr=3.894202018273063e-05 loss=0.08299709111452103\n",
      "epoch=93 lr=3.904689947376028e-05 loss=0.08299640566110611\n",
      "epoch=94 lr=3.91464309359435e-05 loss=0.08293475955724716\n",
      "epoch=95 lr=3.924058546544984e-05 loss=0.08277945220470428\n",
      "epoch=96 lr=3.9329341234406456e-05 loss=0.0829160287976265\n",
      "epoch=97 lr=3.94126727769617e-05 loss=0.08263184130191803\n",
      "epoch=98 lr=3.949055462726392e-05 loss=0.08291900157928467\n",
      "epoch=99 lr=3.9562961319461465e-05 loss=0.08261646330356598\n",
      "epoch=100 lr=3.962988193961792e-05 loss=0.0826425701379776\n",
      "epoch=101 lr=3.9691291021881625e-05 loss=0.08270931988954544\n",
      "epoch=102 lr=3.974717401433736e-05 loss=0.0826399028301239\n",
      "epoch=103 lr=3.979750908911228e-05 loss=0.0827610194683075\n",
      "epoch=104 lr=3.9842288970248774e-05 loss=0.08254894614219666\n",
      "epoch=105 lr=3.9881502743810415e-05 loss=0.08255968242883682\n",
      "epoch=106 lr=3.991513585788198e-05 loss=0.08260182291269302\n",
      "epoch=107 lr=3.9943173760548234e-05 loss=0.0826568603515625\n",
      "epoch=108 lr=3.996561645180918e-05 loss=0.08253723382949829\n",
      "epoch=109 lr=3.99824530177284e-05 loss=0.08240795880556107\n",
      "epoch=110 lr=3.9993683458305895e-05 loss=0.08251143991947174\n",
      "epoch=111 lr=3.9999300497584045e-05 loss=0.08235416561365128\n",
      "epoch=112 lr=3.9999300497584045e-05 loss=0.08260670304298401\n",
      "epoch=113 lr=3.999369073426351e-05 loss=0.08246292173862457\n",
      "epoch=114 lr=3.998246756964363e-05 loss=0.0821480080485344\n",
      "epoch=115 lr=3.996563464170322e-05 loss=0.08246731013059616\n",
      "epoch=116 lr=3.9943199226399884e-05 loss=0.08224262297153473\n",
      "epoch=117 lr=3.991516132373363e-05 loss=0.08217307180166245\n",
      "epoch=118 lr=3.9881539123598486e-05 loss=0.08221583068370819\n",
      "epoch=119 lr=3.984233262599446e-05 loss=0.0822344571352005\n",
      "epoch=120 lr=3.9797556382836774e-05 loss=0.08230280876159668\n",
      "epoch=121 lr=3.974722494604066e-05 loss=0.08202328532934189\n",
      "epoch=122 lr=3.969134922954254e-05 loss=0.08202638477087021\n",
      "epoch=123 lr=3.9629947423236445e-05 loss=0.08190978318452835\n",
      "epoch=124 lr=3.95630304410588e-05 loss=0.08201901614665985\n",
      "epoch=125 lr=3.949062738684006e-05 loss=0.0819195881485939\n",
      "epoch=126 lr=3.941274917451665e-05 loss=0.08171315491199493\n",
      "epoch=127 lr=3.9329428545897827e-05 loss=0.0819551944732666\n",
      "epoch=128 lr=3.924067641492002e-05 loss=0.08176072686910629\n",
      "epoch=129 lr=3.9146525523392484e-05 loss=0.0819348394870758\n",
      "epoch=130 lr=3.9047001337166876e-05 loss=0.08192836493253708\n",
      "epoch=131 lr=3.8942125684116036e-05 loss=0.08173386007547379\n",
      "epoch=132 lr=3.883193858200684e-05 loss=0.081647127866745\n",
      "epoch=133 lr=3.8716465496690944e-05 loss=0.0815071165561676\n",
      "epoch=134 lr=3.859573189401999e-05 loss=0.08147221058607101\n",
      "epoch=135 lr=3.8469785067718476e-05 loss=0.08145289123058319\n",
      "epoch=136 lr=3.833865048363805e-05 loss=0.08150055259466171\n",
      "epoch=137 lr=3.8202368159545586e-05 loss=0.08177012950181961\n",
      "epoch=138 lr=3.8060978113207966e-05 loss=0.08160659670829773\n",
      "epoch=139 lr=3.791452036239207e-05 loss=0.0814320519566536\n",
      "epoch=140 lr=3.776303128688596e-05 loss=0.08120737224817276\n",
      "epoch=141 lr=3.760655818041414e-05 loss=0.08121152967214584\n",
      "epoch=142 lr=3.7445137422764674e-05 loss=0.08124125003814697\n",
      "epoch=143 lr=3.7278827221598476e-05 loss=0.08128488808870316\n",
      "epoch=144 lr=3.710766031872481e-05 loss=0.08108100295066833\n",
      "epoch=145 lr=3.693169492180459e-05 loss=0.08138883113861084\n",
      "epoch=146 lr=3.675097832456231e-05 loss=0.08106806874275208\n",
      "epoch=147 lr=3.656555782072246e-05 loss=0.08085615932941437\n",
      "epoch=148 lr=3.637548797996715e-05 loss=0.08084773272275925\n",
      "epoch=149 lr=3.618081973399967e-05 loss=0.08098548650741577\n",
      "epoch=150 lr=3.598161129048094e-05 loss=0.08072782307863235\n",
      "epoch=151 lr=3.577791721909307e-05 loss=0.0807228535413742\n",
      "epoch=152 lr=3.556979572749697e-05 loss=0.08078335970640182\n",
      "epoch=153 lr=3.535730138537474e-05 loss=0.08087658137083054\n",
      "epoch=154 lr=3.514050331432372e-05 loss=0.08048451691865921\n",
      "epoch=155 lr=3.49194451700896e-05 loss=0.08066511899232864\n",
      "epoch=156 lr=3.469419971224852e-05 loss=0.08069107681512833\n",
      "epoch=157 lr=3.446483606239781e-05 loss=0.08064887672662735\n",
      "epoch=158 lr=3.4231408790219575e-05 loss=0.08028585463762283\n",
      "epoch=159 lr=3.3993987017311156e-05 loss=0.0802907794713974\n",
      "epoch=160 lr=3.375263622729108e-05 loss=0.08054166287183762\n",
      "epoch=161 lr=3.3507425541756675e-05 loss=0.08038027584552765\n",
      "epoch=162 lr=3.325842408230528e-05 loss=0.08021032810211182\n",
      "epoch=163 lr=3.300570824649185e-05 loss=0.08042652904987335\n",
      "epoch=164 lr=3.274933624197729e-05 loss=0.0800519585609436\n",
      "epoch=165 lr=3.248938446631655e-05 loss=0.08018163591623306\n",
      "epoch=166 lr=3.222592204110697e-05 loss=0.07990071177482605\n",
      "epoch=167 lr=3.1959036277839914e-05 loss=0.07993071526288986\n",
      "epoch=168 lr=3.168879266013391e-05 loss=0.07978863269090652\n",
      "epoch=169 lr=3.1415267585543916e-05 loss=0.07996170222759247\n",
      "epoch=170 lr=3.113853745162487e-05 loss=0.07995752990245819\n",
      "epoch=171 lr=3.085868229391053e-05 loss=0.07990874350070953\n",
      "epoch=172 lr=3.057577850995585e-05 loss=0.07967771589756012\n",
      "epoch=173 lr=3.0289906135294586e-05 loss=0.07971809059381485\n",
      "epoch=174 lr=3.0001148843439296e-05 loss=0.0796566754579544\n",
      "epoch=175 lr=2.9709581212955527e-05 loss=0.07942468672990799\n",
      "epoch=176 lr=2.9415281460387632e-05 loss=0.07945834845304489\n",
      "epoch=177 lr=2.9118338716216385e-05 loss=0.07945135980844498\n",
      "epoch=178 lr=2.881885302485898e-05 loss=0.07944317907094955\n",
      "epoch=179 lr=2.8516884412965737e-05 loss=0.07929804921150208\n",
      "epoch=180 lr=2.8212525648996234e-05 loss=0.0791928619146347\n",
      "epoch=181 lr=2.7905860406463034e-05 loss=0.07925307005643845\n",
      "epoch=182 lr=2.75969705398893e-05 loss=0.07905494421720505\n",
      "epoch=183 lr=2.728595609369222e-05 loss=0.07901769876480103\n",
      "epoch=184 lr=2.6972897103405558e-05 loss=0.07903199642896652\n",
      "epoch=185 lr=2.6657879061531276e-05 loss=0.0788295567035675\n",
      "epoch=186 lr=2.6340994736528955e-05 loss=0.07907552272081375\n",
      "epoch=187 lr=2.6022331439889967e-05 loss=0.07879389077425003\n",
      "epoch=188 lr=2.5701976483105682e-05 loss=0.07876011729240417\n",
      "epoch=189 lr=2.5380020815646276e-05 loss=0.07867168635129929\n",
      "epoch=190 lr=2.5056555386981927e-05 loss=0.07860483974218369\n",
      "epoch=191 lr=2.4731671146582812e-05 loss=0.07866086065769196\n",
      "epoch=192 lr=2.4405460862908512e-05 loss=0.07848525047302246\n",
      "epoch=193 lr=2.4078004571492784e-05 loss=0.07862772047519684\n",
      "epoch=194 lr=2.3749413230689242e-05 loss=0.07842084765434265\n",
      "epoch=195 lr=2.3419768695021048e-05 loss=0.07825013250112534\n",
      "epoch=196 lr=2.308916373294778e-05 loss=0.07828468829393387\n",
      "epoch=197 lr=2.2757696569897234e-05 loss=0.07835362106561661\n",
      "epoch=198 lr=2.2425452698371373e-05 loss=0.0783260241150856\n",
      "epoch=199 lr=2.2092519429861568e-05 loss=0.07816227525472641\n",
      "epoch=200 lr=2.1759007722721435e-05 loss=0.07807103544473648\n",
      "epoch=201 lr=2.1425004888442345e-05 loss=0.07806695997714996\n",
      "epoch=202 lr=2.1090601876494475e-05 loss=0.07793515920639038\n",
      "epoch=203 lr=2.075589327432681e-05 loss=0.07788192480802536\n",
      "epoch=204 lr=2.0420971850398928e-05 loss=0.07784659415483475\n",
      "epoch=205 lr=2.008593401114922e-05 loss=0.07797537744045258\n",
      "epoch=206 lr=1.9750872525037266e-05 loss=0.07776566594839096\n",
      "epoch=207 lr=1.9415880160522647e-05 loss=0.07761611044406891\n",
      "epoch=208 lr=1.9081055143033154e-05 loss=0.07755095511674881\n",
      "epoch=209 lr=1.874647568911314e-05 loss=0.07758081704378128\n",
      "epoch=210 lr=1.8412260033073835e-05 loss=0.07750420272350311\n",
      "epoch=211 lr=1.8078491848427802e-05 loss=0.07733199000358582\n",
      "epoch=212 lr=1.774526208464522e-05 loss=0.07725705206394196\n",
      "epoch=213 lr=1.7412667148164473e-05 loss=0.07729509472846985\n",
      "epoch=214 lr=1.708079616946634e-05 loss=0.07720137387514114\n",
      "epoch=215 lr=1.674973827903159e-05 loss=0.07712603360414505\n",
      "epoch=216 lr=1.6419602616224438e-05 loss=0.07725665718317032\n",
      "epoch=217 lr=1.6090471035568044e-05 loss=0.07704689353704453\n",
      "epoch=218 lr=1.5762439943500794e-05 loss=0.07702665776014328\n",
      "epoch=219 lr=1.5435598470503464e-05 loss=0.07680519670248032\n",
      "epoch=220 lr=1.511003756604623e-05 loss=0.07694040983915329\n",
      "epoch=221 lr=1.4785851817578077e-05 loss=0.07689909636974335\n",
      "epoch=222 lr=1.4463129446085077e-05 loss=0.07682166248559952\n",
      "epoch=223 lr=1.4141962310532108e-05 loss=0.07668165117502213\n",
      "epoch=224 lr=1.3822439541399945e-05 loss=0.07656940817832947\n",
      "epoch=225 lr=1.3504642083717044e-05 loss=0.0767744928598404\n",
      "epoch=226 lr=1.318867907684762e-05 loss=0.0765053853392601\n",
      "epoch=227 lr=1.2874626918346621e-05 loss=0.07644669711589813\n",
      "epoch=228 lr=1.2562577467178926e-05 loss=0.07641707360744476\n",
      "epoch=229 lr=1.2252617125341203e-05 loss=0.07649111747741699\n",
      "epoch=230 lr=1.1944830475840718e-05 loss=0.07647287100553513\n",
      "epoch=231 lr=1.1639305739663541e-05 loss=0.07626831531524658\n",
      "epoch=232 lr=1.133612749981694e-05 loss=0.07616283744573593\n",
      "epoch=233 lr=1.1035375791834667e-05 loss=0.07624471932649612\n",
      "epoch=234 lr=1.0737148841144517e-05 loss=0.07614436000585556\n",
      "epoch=235 lr=1.0441522135806736e-05 loss=0.07615530490875244\n",
      "epoch=236 lr=1.0148580258828588e-05 loss=0.0760815218091011\n",
      "epoch=237 lr=9.858404155238532e-06 loss=0.07603650540113449\n",
      "epoch=238 lr=9.571076589054428e-06 loss=0.07588515430688858\n",
      "epoch=239 lr=9.286676686315332e-06 loss=0.07589379698038101\n",
      "epoch=240 lr=9.005284482554998e-06 loss=0.07570710778236389\n",
      "epoch=241 lr=8.726973646844272e-06 loss=0.0757644921541214\n",
      "epoch=242 lr=8.451832400169224e-06 loss=0.07566524296998978\n",
      "epoch=243 lr=8.179936230590101e-06 loss=0.07567209750413895\n",
      "epoch=244 lr=7.911357897683047e-06 loss=0.07563567161560059\n",
      "epoch=245 lr=7.646172889508307e-06 loss=0.07553929090499878\n",
      "epoch=246 lr=7.384458058368182e-06 loss=0.07560957968235016\n",
      "epoch=247 lr=7.126283890102059e-06 loss=0.0754794031381607\n",
      "epoch=248 lr=6.87172359903343e-06 loss=0.07540281862020493\n",
      "epoch=249 lr=6.62085085423314e-06 loss=0.07529501616954803\n",
      "epoch=250 lr=6.373735686793225e-06 loss=0.07521078735589981\n",
      "epoch=251 lr=6.1304372138693e-06 loss=0.07542195171117783\n",
      "epoch=252 lr=5.891041382710682e-06 loss=0.07515565305948257\n",
      "epoch=253 lr=5.655605036736233e-06 loss=0.07529058307409286\n",
      "epoch=254 lr=5.424200026027393e-06 loss=0.07519297301769257\n",
      "epoch=255 lr=5.196880920266267e-06 loss=0.07520414888858795\n",
      "epoch=256 lr=4.973724571755156e-06 loss=0.07520712167024612\n",
      "epoch=257 lr=4.754778274218552e-06 loss=0.07513530552387238\n",
      "epoch=258 lr=4.540117970464053e-06 loss=0.07497373223304749\n",
      "epoch=259 lr=4.329796411184361e-06 loss=0.07511777430772781\n",
      "epoch=260 lr=4.123878170503303e-06 loss=0.07504121214151382\n",
      "epoch=261 lr=3.922415544366231e-06 loss=0.07494932413101196\n",
      "epoch=262 lr=3.725468332049786e-06 loss=0.074896439909935\n",
      "epoch=263 lr=3.533090875862399e-06 loss=0.07485692948102951\n",
      "epoch=264 lr=3.345333425386343e-06 loss=0.07495750486850739\n",
      "epoch=265 lr=3.1622535061615054e-06 loss=0.07473168522119522\n",
      "epoch=266 lr=2.983900913022808e-06 loss=0.07485850155353546\n",
      "epoch=267 lr=2.8103208933316637e-06 loss=0.07471615076065063\n",
      "epoch=268 lr=2.6415707452542847e-06 loss=0.0748014822602272\n",
      "epoch=269 lr=2.477698217262514e-06 loss=0.07473327219486237\n",
      "epoch=270 lr=2.3187460556073347e-06 loss=0.07482635229825974\n",
      "epoch=271 lr=2.1647504127031425e-06 loss=0.07469252496957779\n",
      "epoch=272 lr=2.015765403484693e-06 loss=0.07458953559398651\n",
      "epoch=273 lr=1.8718275214268942e-06 loss=0.0746396854519844\n",
      "epoch=274 lr=1.7329816728306469e-06 loss=0.07482746243476868\n",
      "epoch=275 lr=1.5992660564734251e-06 loss=0.07458007335662842\n",
      "epoch=276 lr=1.4707140962855192e-06 loss=0.07449283450841904\n",
      "epoch=277 lr=1.34736865220475e-06 loss=0.07443839311599731\n",
      "epoch=278 lr=1.2292583733142237e-06 loss=0.07436899840831757\n",
      "epoch=279 lr=1.1164165698573925e-06 loss=0.07443805783987045\n",
      "epoch=280 lr=1.0088766657645465e-06 loss=0.07454211264848709\n",
      "epoch=281 lr=9.066674238056294e-07 loss=0.0743437334895134\n",
      "epoch=282 lr=8.098219268504181e-07 loss=0.07442692667245865\n",
      "epoch=283 lr=7.183593879744876e-07 loss=0.07442381232976913\n",
      "epoch=284 lr=6.323131742647092e-07 loss=0.07422773540019989\n",
      "epoch=285 lr=5.517071599570045e-07 loss=0.07437479496002197\n",
      "epoch=286 lr=4.765603591749823e-07 loss=0.07410585135221481\n",
      "epoch=287 lr=4.068918713073799e-07 loss=0.07422129809856415\n",
      "epoch=288 lr=3.4272554216840945e-07 loss=0.07428503781557083\n",
      "epoch=289 lr=2.840756678779144e-07 loss=0.07416900247335434\n",
      "epoch=290 lr=2.3096609425010683e-07 loss=0.07422371953725815\n",
      "epoch=291 lr=1.8340634255764598e-07 loss=0.07416673004627228\n",
      "epoch=292 lr=1.414083499184926e-07 loss=0.07419059425592422\n",
      "epoch=293 lr=1.0498401081804332e-07 loss=0.07413207739591599\n",
      "epoch=294 lr=7.41500372214432e-08 loss=0.07430032640695572\n",
      "epoch=295 lr=4.890880234142969e-08 loss=0.07413506507873535\n",
      "epoch=296 lr=2.9265075696116583e-08 loss=0.07415470480918884\n",
      "epoch=297 lr=1.5233160510774724e-08 loss=0.07415387779474258\n",
      "epoch=298 lr=6.808289931825584e-09 loss=0.0741974487900734\n",
      "epoch=299 lr=3.999999886872274e-09 loss=0.07420003414154053\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for epoch in tqdm(range(epoch, total_epochs)):\n",
    "    epoch_loss = 0.0\n",
    "    num_samples = 0\n",
    "    all_batch_data = get_derivative_dataset(rng)\n",
    "    for minibatch in range(minibatch_per):\n",
    "        fraction = (epoch + minibatch/minibatch_per)/total_epochs\n",
    "        batch_data = (all_batch_data[0][minibatch*batch:(minibatch+1)*batch], all_batch_data[1][minibatch*batch:(minibatch+1)*batch])\n",
    "        rng += 10\n",
    "        opt_state, params = update_derivative(fraction, opt_state, batch_data, 1e-6)\n",
    "        cur_loss = loss(params, batch_data, 0.0)\n",
    "        epoch_loss += cur_loss\n",
    "        num_samples += batch\n",
    "    closs = epoch_loss/num_samples\n",
    "    print('epoch={} lr={} loss={}'.format(\n",
    "        epoch, OneCycleLR(fraction), closs)\n",
    "         )\n",
    "    if closs < best_loss:\n",
    "        best_loss = closs\n",
    "        best_params = [[copy(jax.device_get(l2)) for l2 in l1] if len(l1) > 0 else () for l1 in params]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**TODO:** inspect the distribution of the learned weights to see whether it suggests a better model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_params = pkl.load(\n",
    "    open('best_dblpendulum_baseline_v5_900epoch.pt', 'rb')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# p = get_params(opt_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt_state = opt_init(best_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pkl.dump(\n",
    "#     best_params,\n",
    "#     open('best_dblpendulum_baseline_v5_900epoch.pt', 'wb')\n",
    "# )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Make sure the args are the same:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# opt_state = opt_init(loaded['params'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([7, 7], dtype=uint32)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rng+7"
   ]
  },
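  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Seed `[8, 8]` looks promising. Use `args.n_updates=3` together with the parameter file `params_for_loss_0.29429444670677185_nupdates=1.pkl`.\n",
    "\n",
    "**NOTE:** the file name says `nupdates=1`, which contradicts `args.n_updates=3` above; confirm which value is correct."
   ]
  },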
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rc('font', family='serif')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_t = 10\n",
    "new_dataset = new_get_dataset(jax.random.PRNGKey(2),\n",
    "                              t_span=[0, max_t],\n",
    "                              fps=10, test_split=1.0,\n",
    "                              unlimited_steps=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = new_dataset['x'][0, :]\n",
    "tall = [jax.device_get(t)]\n",
    "p = get_params(opt_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(99, 4)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_dataset['x'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_tall = jax.device_get(odeint(\n",
    "    partial(baseline_eom, learned_dynamics(p)),\n",
    "    t,\n",
    "    np.linspace(0, max_t, num=new_dataset['x'].shape[0])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def kinetic_energy(state, m1=1, m2=1, l1=1, l2=1, g=9.8):\n",
    "    q, q_dot = jnp.split(state, 2)\n",
    "    (t1, t2), (w1, w2) = q, q_dot\n",
    "\n",
    "    T1 = 0.5 * m1 * (l1 * w1)**2\n",
    "    T2 = 0.5 * m2 * ((l1 * w1)**2 + (l2 * w2)**2 + 2 * l1 * l2 * w1 * w2 * jnp.cos(t1 - t2))\n",
    "    T = T1 + T2\n",
    "    return T\n",
    "\n",
    "@jit\n",
    "def potential_energy(state, m1=1, m2=1, l1=1, l2=1, g=9.8):\n",
    "    q, q_dot = jnp.split(state, 2)\n",
    "    (t1, t2), (w1, w2) = q, q_dot\n",
    "\n",
    "    y1 = -l1 * jnp.cos(t1)\n",
    "    y2 = y1 - l2 * jnp.cos(t2)\n",
    "    V = m1 * g * y1 + m2 * g * y2\n",
    "    return V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5, 0, 'Time')"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZAAAAEECAYAAAAGSGKZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3hUVf7H8feZmfTeCZBCAgk9lBBq6EhXsYIdV+zrqqvuuru2Xbvu6rrrqij2n2CviHSkl9BBQyCFQGiBkEJ6Zs7vj4y7USlJSHKnfF/PMw+ZkpnPfSbDd+6953yP0lojhBBCNJXJ6ABCCCGckxQQIYQQzSIFRAghRLNIARFCCNEsUkCEEEI0i8XoAK0tPDxcx8fHGx1DCCGcyubNm49rrSPO9hiXLyDx8fFkZGQYHUMIIZyKUmr/uR4jh7CEEEI0ixQQIYQQzSIFRAghRLM41DkQpVQ74HEgRWs94DT3m4AngVNAHDBHa72+bVMKIYQABysgwDDgS6DPGe6/AgjUWv9RKRUKrFdKddNaW9ssoRBCCMDBDmFprT8Bys7ykMnAOvtji4AqoEcbRBNCCPELDlVAGiGSnxeYUvttP6OUulkplaGUyigsLGyzcEII4U6crYAcAwIaXA+03/YzWuvZWutUrXVqRMRZ58EIIZpAa01O4Sk+2JDPuuwTRscRBnO0cyC/opTyA3y11oXAfGA48J79HIg3sNvIfEK4DJsV9nwLFSfAVld/3WalpKKS/ONlHDheRkFRGRXVNViw8q2tB7snXsFvhnVCKWV0emEAhyogSqkRwLVAtFLqL8DfgRuAXsCtwEdAX6XUI0AscJ2cQBeihax4GlY++6ubg6j/APb66QYLaGXiVv0NsxbUcqBoGg9P7YHZJEXE3ShXX5EwNTVVSysTIc5h/zr025NY7jmSB0suwYYJby9P+sSFMyAhgrTECJLahWAyW8Bkguoy9DtTqTv8A1dV/YGgriN4aUYffD0d6jupOA9Kqc1a69SzPkYKyOlVlpex6/WbIT6djv3HEx2T2ArphHAAlcXwajrHymuZVPMUN47uxZDEcHq2D8RiPstp0vIT8OZ4akqOcHHFn7C0780b16cSGeDddtlFq5ECQvMLSM6uDYR+cgnBnAIgX7WnIGQApoQRxPW7gHbtY1o6qhBtT2v49Cb07s+ZVvUI48dP4baRTfiyVHwA3hxPdU01U8sfosI/lrdnDqBzZMC5f1c4NCkgnN8hLJvVSt4PGzm+czHeB1eTWLEdP6oA2KfiORw6AEvicDr1v4B2Ue1aMLUQbWT7h/D5zbzpeTXveV7Bd3en42UxN+05CvfAmxOotvhxYcUjHLYGMvu6VAYlhLVOZtEmpIDQsudAbLU17N+1lhO7FuN7aC0JlTvxpharVmzxSqP99XPo0EH2TISTKMqFV9M57JPI0KP3MWfmQEYl/2paVeMUbIa3p1ITGMsVNQ+xu0jx/OUpXNSnQ8tmFm1GCgitexLdWlNJ/o6VFO9YQPf973NCBXHogtmkDhnTKq8nRIux1sFbE7EVZnJB5ZPEJ3bljet/1X6uabKXwwdXUNeuLzOtD7Iqr4L7xydz+8hEGebrhBpTQJxtIqFDMXv60Cl1PH1vfJHjV36NWSl6LbyShe8/j9Xm2oVZOLmVz8LBjbwfdjf5tnAemtL9/J8zcRRc8jqWgo287f8y01IieW7hHh78bKd8HlyUFJAW0qH7YALvWsN+/96M3/c3Vvz9GopKTxkdS4hfy18PK5/jeOI0Hs7pys3pCcSF+bXMc/e4GKa8gHnfYv7hOZs7RnZi3qYDvLM2r2WeXzgUKSAtyCckii73LuSHhJmMKf+GQy+MYtePPxodS4j/qSqBT2ehg2O55cQM2gd5c/uoFh6injoTRj+E2vkR99neJr1zGC8syeL4qeqWfR1hOCkgLUyZPeh+3YvsH/MKCTqfqHnj
+e7bz3D1c03CScz/PZQWsDDpb2w+UsefJndrncl/6b+HQXegNr7Gi+2XUFlj5dnvMlv+dYShpIC0krj0q6i7cQlWjwDGbLiJT195iPKqWqNjCXe24yPY+TGVQ+7jj5u8GZwQxuRe0a3zWkrBBY9DygzCNj7H4z0P81HGQbYdKG6d1xOGkALSigJjexF57xoOhg/jsmP/Yt3fLyW7QNrLCwOczINv7oWYQTx1aiJlVXU8emGP1h0dZTLB1JcgpBOXn3yDKH8Lj3y5C5ucUHcZUkBamck3mE53fMH+lHsZXbuS2tlj2P6DnBcRbchaB5/OAqXIGvZ33t94iGsHxZHcrg1mi1s8YcxDmAt/4OWee9l+sIRPthxs/dcVbUIKSFswmYib9gjF0/6POHUU28czKSwpNzqVcBdrXoCDG9GT/86fl5cS4uvJPeOS2u71u0+D6D70z3mFQbF+PPtdJqVyONclSAFpQ6Epkyka9Qx99Y+sef1u6qw2oyMJV6c1bHoTOo/jS+tQNuWd5IEJyQT5eLRdBpMJxv0VVXqQFzpt5ER5DS8u3tt2ry9ajRSQNtZhxA3kxl7Kxac+4rOP3jY6jnB1R3ZC2SGqkqby5Lc/ktIxiMv7G9BuJ2EEJI4hevvLzOwXzDvr8sg6WnbOXxOOTQqIATpd+zKHvRMZm/kQ32/aanQc4cr2LgRg9qFEjpVV8+iFPTAZtfDTuMegqoT7/Rbg72Xh0a92y/B2JycFxAgePoTNnIu3yUrgN7eQd1SGNopWkrWIqsgU/rWplMv7d6RvbIhxWdr1gt5X4LN5Ng+lB7I2+wQLdh0xLo84b1JADOIZlUzl+H/QV+1h05v3UFkjK/OKFlZ+Ag5uYmldCt4WMw9M6Gp0Ihj1Z0BzScm7dG0XwBPzf5S/fScmBcRAYYOuoqDzDC6v/oz/e2+27M6LlrVvCaB553gyF/VtT0SAl9GJICQOBszCtGMuz6ZbKCiu5JXvs41OJZpJCojBOlz5Isf8krgs/298tXKD0XGEK9m7kBrvcDbVxDGsc4TRaf5n+H3g6U/vPf/kwpT2vPp9NgeKKoxOJZpBCojRPLwJmzkPL5Mmbukd7Nx/zOhEwhVY62DfErICBqGUicGJDrQ6oG8oDLsbsr7jkd7FWEyKv33zg9GpRDNIAXEA5vBE6qa+RB/TPna/ey/FFTVGRxLO7uBGqCphfnVvUmKC23beR2MMvA0Coglb9wR3jkpk0Q9HWZklbX6cjRQQBxHQ73KOd7uO6davefetl6VfkDg/WQvRJgsfFCaQ3jnc6DS/5ukLIx+Eg5uYFbGbTuF+PPr1bmrqZHKtM5EC4kDCL32eE4HduOHYs7y74Huj4whntncRRWH9KdG+DOviQOc/GupzNYQn47H8bzwyqQs5heW8vTbX6FSiCaSAOBKLF6HXf4CHWdF3wz2sziwwOpFwRsUH4NgPbLCk4udppm9ssNGJTs9sgbGPwIl9jKxYyIikCP6zIpuKmjqjk4lGkgLiYFRYAqaLXybFlEP+Rw9QVStj5EUT2Wefzy3uxqCEMDzMDvwxT54EMYNgxdPcPbw9xRW1fLjpgNGpRCM58F+W+/LqPY2jXWZwhfVbvly22ug4wtlkLaIuMI5VJ0MY1sUBz380pBSM+yucOkrfgrkMiA/hjVW51EqjUacgBcRBRV34KDaTBZ/1L1BeLbv0opFqKyF3JdkhQwFFuqMXEIDYgdB1Cqz5J3cNCqGguJKvtx8yOpVoBCkgjiqgHcXdrmaS7Xs+W7LS6DTCWeSugrpKFtWm0C7Qm8QIf6MTNc6Yh6G2nGGH3iI5KoBXv8+WkYhOQAqIA4uc+EdsJguBm16kpEIW4BGNsHch2sOXdw91ZFiX8NZdsrYlRSRDygzUlne5a0goWUdPsXyPTKp1dFJAHFlAO8p6XscUvZJPFq8wOo1wdFpD1iJKo4dSWOUkh68aGvJbqKtkQuW3dAj24ZUV0iPL0UkBcXBhFzyA1eRJ
+JaXOHGq2ug4wpEVZkJJPlu80gAY6ogTCM8msht0Hot502xuHdKejP0n2ZRXZHQqcRZSQBxdQBTlva9nCqv4aOFyo9MIR5ZVP3z349KudI8OJNzfAbrvNtXgO6G8kCu91xPq58mrshfi0KSAOIGQcfdjNXkSvf3fHC2tMjqOcFR7F2GL7MnigxbnO3z1k4SRENULz02vcP2gOJZmHmPPEVn61lFJAXEG/pFU9ZnJVLWaeQuWGp1GOKLKk5C/nvzwdGqt2vHnf5yJUjDkTijM5MZ2+/D1NPOarBfisKSAOInAMb/HavIkfvfLsnaC+LXsZaCtLLf1wdNiYkB8qNGJmq/HJRAQTcCWV5g+IJYvtx/i4En5m3dEUkCchX8kNf1uZIpay7wFS4xOIxxN1iLwCeWjQ1GkxYfi7WE2OlHzWTxh4K2Qu5Jbk8tRwBurpMmiI5IC4kT8R/0eq9mLpMxXySk8ZXQc4ShsVti3mKq4Ufx4rMJ5D1811P8G8PQnctfrXNy3A/M25VNULuvkOBopIM7EP4K6fr9hqmktH8yXvRBhV7AFKk6ww3cgAMOcbfju6fgEQ99rYden3NHfm6paG2+vzTM6lfgFhyogSqmxSqn/KKUeVUo9cpr7b1BKrVdKrbBfrjUip5F8R91LndmLXtmvkXmk1Og4whHsXQjKxFenuhHm50n36ECjE7WMQbeCttFp3/uM6x7Fu+vypC+cg3GYAqKU8gVeBe7RWj8K9FZKjTnNQ6drrUfaL++1aUhH4BeOLXUWU03rmDd/kdFphCPIWoiOGcii3BqGdA7HZHKS9iXnEhIP3S+Cze9w+5AoiitqmSet3h2KwxQQYDCwX2v903TrNcDk0zzuTqXUfUqph5VSTjzUpPm8R9xNndmbfnlvsONgsdFxhJFKD8ORHRyPHsGxsmrHXL72fAz+LVSX0Lfwa9I6hfLGqhxZ9taBOFIBiQQazhgqtd/W0PfAM1rr54EM4OPTPZFS6malVIZSKqOwsLBVwhrKLxydNosp5vXM+0b2Qtza3vr3fyX9AFzjBHpDHftD7GBY/wq3p8dxuKSKr6TVu8NwpAJyDAhocD3Qftt/aa1ztdY/VYRlwAil1K/GK2qtZ2utU7XWqRERDroe9HnySr+bOrMPgwvmSL8gd7Z3EQR25OvDwSRE+NE+2MfoRC1vyG+hJJ8R1nV0bSet3h2JIxWQdUCcUuqnBj5DgflKqVClVCCAUuoppZTFfn8XIFdr7Z5rvvqFoQbewmTzBuZ9sxCt5QPlduqqIXs51s7j2JB70vUOX/0kaSKEJqLW/ZvbRiSw79gplmZKq3dH4DAFRGtdAdwGvKSUehzYobVeCvwRuN3+sCPAK0qpPwF/AtxuFFZDHsPuwmr2YdTRt9iYK3shbmf/GqgtJytoCJW1VoZ1cc29bUwmGHw7HNrC5OA8Oob48J8V++RLkwNwmAICoLVerLW+RWv9F631Y/bbHtBaP23/+Z9a61la6ye11jO01uuNTWww31DUwFuZYt7A+g2ydrrbyVoEFm8WnErCbFIMSnDhMSUpV4FPKJb1L3Pz8AS25heTsf+k0ancnkMVENF0lqF3YMOEf9aXWOW4sHvZuxDi01mRe4q+McEEeHsYnaj1ePrCgJtgzwIuj68iyMeDt9ZIexOjSQFxdn7hFIWnMtS6kc3yjcx9HN8HRTlUxI1hZ0GJ642+Op20WWD2xGfza0wfEMPC3UcpKK40OpVbkwLiAgJSLqKr6QDrMzYZHUW0lez6tv7rLf3RGudd/6Mp/CMh5UrY9gHXp/ihtea9dfuNTuXWpIC4AK+eUwBQmfPlxKK7yF4GoQksPuxNgJeFlI7BRidqG4PvhLoq2u+dywXd2zFvUz6VNe45ENMRSAFxBSHxFAcmM7B2PdsPlhidRrS2uhrIXYVOGM2qvccZlBiGxewmH+WIZOhyAWx6nRsHRVNcUcsX2wqMTuW23OSvzvV597yQVJXF91t3Gx1FtLYD
G6C2nGNRQzl4stI9Dl81ZF83fUDZUrpFB/L2mjzZ8zaIFBAX4d3rQkxKU7XrW/kwubrsZWCysLw6GYB0V53/cSadhkNEN9Sm15k5JI49R8tYl33C6FRuSQqIq2jXi1M+7elXuZYfD5ed+/HCeWUvg45pLM+tpEOwD/FhvkYnaltKQdpNcHg7F0UcItTPk7dkrRBDSAFxFUph7jaZ4aadLN2ebXQa0VrKj8Ph7VgTRrE2+wTpXcJRykXatzdF7+ngFYjXlje5Ki2WJT8eJf+ErJve1qSAuBCfXhfipWop2vGd0VFEa8lZAWiy/NMoq6pzj/kfp+PlDykzYPfnXNfLB7NSvLsuz+hUbkcKiCuJHUK1JZCep1az75isme6SspeBTwjzj0dgNinSO7vZ+Y+GBtwE1hoi933IxF7RfJhxQFYsbGNSQFyJ2YKtywTGmLayaIes3OZytK4vIAkjWZJ5gtS4EIJ8Xbh9yblEJEHCKMh4ixsGdaSsqo7Pthw0OpVbkQLiYnx6X0iwKufA9mVGRxEt7diPUHaYk+3TyTxSxthuUUYnMl7aLCgtoF/lWlI6BvHW2jxZK6QNSQFxNYmjqTN5kXTyew4UyUlFl5Jd/6VgWXUPAEZ3++WCnW4oaQIExaA2vs4NQ+PJKSxn1b7jRqdyG1JAXI2nH7VxIxhn3sx3Ow8bnUa0pOylEJ7M1/tNdAr3IzHC3+hExjOZYcBvIG8Vk9uVEhHgJV1625AUEBfk0+tCOqrjZG5bY3QU0VJqK2H/Wmo71Q/fHd1V9j7+q+91YPbCc8scrh4Yy4o9heQUyiCStiAFxBUlT8SGidjC5RwpqTI6jWgJ+eugroqdXv2oqbMxRg5f/Y9fGPS8FLbP45o+oXiYFe/IxMI2IQXEFfmFUx09gHGmzSzcfcToNKIlZC8DsyefFcUT4G1hQLwLrz7YHGmzoOYU4dmfMbV3ez7ZfJDSqlqjU7k8KSAuyqfXVLqb9pOxbavRUURL2LcMHTOIhXvLGJEUgYe7dN9trA79oEN/2DibmUPiKa+x8nGGDOltbfJX6KqSJwEQeWgpJ05VGxxGnJeyI3BsN4cjhlJYVi2Hr84k7WY4sZdeNVvpHxfCO2vzZJnnViYFxFWFJVIVksw4UwaLfjhqdBpxPrKXA7C0pgcmBSOTpICcVveLwTccNr3BzKHx5BdVsDzzmNGpXJoUEBfm1XMqA0x7WLV9j9FRxPnIXgp+EXyYH0j/uBBC/DyNTuSYPLyh33Ww51vGd6ihXaA3b8vJ9FYlBcSFqW5TMGPDf/8SSirkhKJTstkgezmVMSPYdfgUo7vK7POzSr0RAI+tb3Pt4DhW7ztO1lFZ3qC1SAFxZdF9qPGLZozKYMmPchjLKR3dCRXH2erZF4Cxcv7j7IJj6s//bXmXGf0i8bKYZC+kFUkBcWVK4dF9CiPMO1i6I8/oNKI59i0F4JOizsSE+tA5Umafn1PaLKg4QWjufKamtOfLrQXSpbeVSAFxcarrZLypQecs55R8iJxP9jJskT2Yn6cZ0zXKPRePaqpOIyA8CTbOZkZaDOU1Vr7ZccjoVC5JCoirix9GnUcAo/VGGZHibGrKIX89B0IHUy2zzxtPqfohvYe20M+cQ+dIf+ZulOUNWoMUEFdn9sDUdSLjLFtZtLPA6DSiKfLWgK2WxdU98PM0M7BTmNGJnEfvK8HTH7XpDaYPiGHbgWIyj5QancrlSAFxA6aukwimjOKsVVTVWo2OIxoreyna4sM7B9sxPCkCT4t8XBvNO7B+ydtdn3JpV288zSbmyV5Ii5O/SHfQeSw2kwcjbBtZmVVodBrRWNnLONVuIAfKNGNk8aimS5sF1hpCMudxQY8oPt9aIF+gWpgUEHfgFQAJI5lgkTVCnEbxATiexVaPvigFI5PdeO3z5opIhvh02Pw201NjKKmsleaiLUwKiJswdZtCR45xaO9mtJb+QA7Pvvrghye70DcmmHB/L4MDOam+10DxfoZ4
ZBET6iOHsVqYFBB3kTQRjSKtah0HT1YanUacS/YyrP7RzD8SJIevzke3qeDhh2nHXK5MjWFdzgnyjpcbncplSAFxFwFRVEb2Y4x5C1vyTxqdRpyNzQo5K9gfPBBQMnz3fHj6QfeLYPcXXJYSjknBhxmyF9JSpIC4Ee8uI+ih8tiZI5OqHNqhrVBVzJKaHnQI9iE5KsDoRM6tzwyoKaPdoaWM7hrJJ5sPUmu1GZ3KJUgBcSOm2IFYlI1TeZuNjiLOJnsZGsVbh+MZ0y1SZp+fr7hhEBQL2z7gygGxFJZVs0wm1bYIKSDupOMAAEKLtlFZI8MZHVb2MspCe3K41o/RXeXw1XkzmSDlSshZzqj2dUQGePHhJjmM1RKkgLgTvzAq/OPoo7LYWVBidBpxOlWlcGAjWzz64OtpZlCCzD5vESkzQNuw7PqYy1M7smLPMQ6XyGCS8yUFxM2YY9Poa9rHlv1FRkcRp5O7ErSVj4qSGNY5HG8Ps9GJXENYIsQMhO1zubJ/DDaNrJneAhyqgCilxiql/qOUelQp9chp7vdWSv1bKfWgUupNpVSSETmdmVf8QCJUCfuzfzQ6ijid7KVYPfxYXBbHWBm+27JSpkNhJrE1WQztHMaHmw5gkzXTz4vDFBCllC/wKnCP1vpRoLdSaswvHnY3kK+1fgp4AZjTtildQEwaAKaCDJlQ6GhqK2H35+QGDaIWCyO7yuzzFtXjEjB7wba5TB8QS0FxJav3HTc6lVNzmAICDAb2a62r7dfXAJN/8ZjJwDoArfVOIEUpFfjLJ1JK3ayUylBKZRQWSu+nn4nsQZ3Zm841P8qEQkez8xOoPMk7dWNJiQkmMsDb6ESuxScYuk6CnR9zQdcQQnw9mLcp3+hUTs2RCkgk0HDx4lL7bU19DFrr2VrrVK11akSEfIv7GbOF6sg+9DPtlQmFjkRr2PgadeHdeP9oLGNk9FXrSLkKKovwylnKJf06sviHoxw/VX3u3xOn5UgF5BjQcMZUoP22pj5GnINPp0F0V/vZkSuN5RzGgQ1wZCfboy9Ha5l93moSR4N/VP3J9AEx1Fo1n22Rk+nN5UgFZB0Qp5T6qWvcUGC+Uiq0wWGq+dQf6kIp1QvYrrWWVWKayBQ7EA9l5VRehtFRxE82vAbeQcwpTaNdoDfdo391ZFa0BLMFel0OWQtJ8q+hX2ww8zYdkPOBzeQwBURrXQHcBryklHoc2KG1Xgr8Ebjd/rB/Ul9k/gL8HviNIWGdnX1CYciJbbI+giMoPQw/fkVp1+ksyCrlkn4dZPZ5a0qZAbZa2PUJ09NiySksZ1OeHM5tjvMuIEqpP7REEACt9WKt9S1a679orR+z3/aA1vpp+8+VWus7tNaPa61v0FpntdRruxX/CCr8YkhRe9lxUCYUGi7jTbBZebNmDBaT4oYh8UYncm3tekK7XrB9LlN6R+PvZZGT6c3U5AKilPqoweVj4KZWyCVamTk2rf5EukwoNFZdNWx+i5rEcby2U3NRnw5EBsroq1aXchUc2opv8T4u7NOeb3cepqSy1uhUTqc5eyClWusr7JfLgSUtHUq0Pq/4QUSpYvKy9xgdxb3t/gLKC1ngcyGVtVZmpScYncg99LocTBbY/gHTB8RQVWvjq20FRqdyOs0pIE/84vqfWyKIaGMx9edBlEwoNNbG17CFdeHxH6MYkRRBcjtp3d4m/COg8zjY8RG9ov3pHh3I3I1yMr2pzllAlFLxSqnnlFKfKaXeACYrpeJ+ul9rLcdAnFFUT+pM3nSp+UEmFBrl4GYo2My2qMsoLK+VvY+21mcGlB1G5a5geloMPxwuZfchGdTZFI3ZA/kSyAReBsYBKcBKpdTLDYbcCmdj9qA6snd9Y0WZUGiMja+hPf157EAK3aIDGdpZOu+2qaQJ4B0M2+YyqVc0SsHiH44ancqpNKaAmLXWc+xDaou01rOARCAPmN2a4UTr8uk0iJ4ql5158qFpc6eOwe7PKYibxvZCG7PS
O8nQ3bZm8YKel0LmN4RbqukbE8zSTPksNEVjCsgSpdSd9p81gNa6Tmv9HPZJfcI5mWLT8FBWSnNlQmGb2/wOWGt4oWQE7QK9mdK7vdGJ3FOfq6CuCn74gjHdothVUMqRkiqjUzmNxhSQe4EgpVQG0N7eqPAapdTLwInWjSdaVcf6zrzBJ7bLhMK2ZK2FjDmUdRjOp/m+3DA0Hk+Lw8zpdS8d+kNYF9g297/t82UvpPHO+VertbZprZ8AhgM3A+2A/sAuYGLrxhOtKiCKSr+OpKgsmVDYln78GsoOM1dNwM/TzIy0WKMTuS+l6k+m568lyfM4HUN8WPqjtNdrrEZ/7dFaV2itv9Ja/1VrfY/W+hWtdXFrhhOtzxQzgH5yIr1tbZxNXVAcz+XEMT0tliAfD6MTubfeVwIKtX0eY7tFsWbfcSprZI+8MWS/2c15xQ8iWhWRky1dYdrE4R2Qv44VgRdhw8TMofFGJxJBHaHTcNg+lzFdw6mus8lCU40kBcTd2ScUmgo2ySSqtrDxNbSHL4/k92FSr2g6hvganUhA/cn04v0M8szF38vC0h/lPEhjSAFxd1G9qDN5kVgtKxS2uooi2PkJmRETKaj2ZlZ6J6MTiZ90uQCUCY/c5YxIimBp5jFZL70RpIC4O4sn1RG9ZIXCtrDlHair4onCYQzsFErvjsFGJxI/8Q2F9v0gexljukVSWFbNzgIZWHIuUkAEPgmD6KVy2ZEno09ajbUONs2hMDyN1WVR3Dxc2pY4nMRRULCZ0XFemBRyGKsRpIAITDFpeKo6SnI3Gx3FdWUtgJIDzK4aS0KEH6OSZclah5M4GrSV4GPr6R8XwhIZzntOUkBEgwmFskJhq9k4m2rfaN483o1Z6QmYTNK2xOF0HACe/vbDWFH8cLiUQ8VyXvBspIAICIym0rc9fWSFwtZxZCfkruQrz0kE+/kwrW8HoxOJ0zF71A/nzV7G2G71e4hyGOvspIAIoH5CYR+ZUNg6ljyG1TOIvx0ZyHWD4/H2MBudSJxJ4mg4mUeiuZC4MF85jHUOUkAEUD+hsKM6Tk7OXqOjuJbclbBvMd+FXEW1JZBrBknbEoeWOBoAlbOMMV2jWJd9gvLqOoNDOS4pIKJex/oJhRyUCbHM4CAAABnFSURBVIUtxmaDxQ9jDejAHw4O5rL+HQnzlyV0HFpoAgTHQvZyxnaLpMZqY9VemZV+JlJARL3o3liVh0wobEk/fA6HtvJZ0PVUaQ9ZcdAZKAUJoyB3JQPiAgnwllnpZyMFRNSzeFEd0VsmFLaUuhpY+leqQrvyYHZ3rh0cR3y4n9GpRGMkjobqUjwOb2VEUgTL98is9DORAiL+y7vTQPuEwkKjozi/zW/ByTz+ZboGP28vfjemi9GJRGN1Gg7KBNnLGNc9iuOnath2UBqPn44UEPFfptg0vFUtJ3O3GB3FuVWVwvfPcDJyEC8f7MRdY7oQ7OtpdCrRWA3amoxMisRsUnIY6wykgIj/kQmFLWPtS1Bxgj+XX058mB/XDoozOpFoqsTRUJBBkConNS5EFpk6Aykg4n+COlDlE0WK2iuN5Jqr7Aise5m8dhP49kQ0f5zYTZardUaJo0HbIHclY7tFkXmkjIMnK4xO5XDkL1v8jIpJo5/ay5b9ciK9WVY8jbbW8rtjU0jrFMr4HlFGJxLN0TEVPAP+250XkL2Q05ACIn7GK34QMaZC9uVkGx3F+RzfC1veZXPENLZXhPLQ5O4oJT2vnJLZAzqlQ85yEiL8SQj3Y4mcB/kVKSDi5+wTCrVMKGy6JY9is3hz58ExXNKvA706BhmdSJwPe1sTinIY0y2SDTlFnJJZ6T8jBUT8XHQKVmWhs0wobJr8DZD5Dd8GXkmxKYj7xycbnUicL3tbk5+689ZYbazKkiHuDUkBET/n4U1NRC/6mvayZp+0cGgUrWHxw9T6RHB/wTBuTk8gOsjH6FTifDVoa5IaF0KQj4c0V/wFKSDiV7w7DSTFlMM3
W/ONjuIc9nwLB9bzhuVK/AOCuGVEotGJREtQqn4vJHclFqyMTK6flW6VWen/JQVE/IqKScObGkr2b5UFdc7FWgdLHuWUfyeeLxzI/Rck4+dlMTqVaCn2tiYUbGZMtyiKymvYdkBGKP5ECoj4tbhhaGVinCmDr7YfMjqNY9v2PhzP4smaK+jSLphL+3c0OpFoSQ3amoxIisBiUnIYqwEpIOLXAqJQCaOY4bmGL7ccMDqN46oph+VPcSQwhQ9Ke/OXyd0xy1K1rsUnxN7WZDlBPh6kdQqVtiYNSAERp9fnKiJshQQXbuTHw6VGp3FM6/8Dp45wf8mljOkaxbAu4UYnEq3B3taEymLGdIsi6+gp8k/IrHSQAiLOpOtkbF6BXG5eyRdbC4xO43hO7odV/+CHoOGsre3Cg5O6GZ1ItJaftTWpn5UukwrrOUQBUUqFKqVmK6X+qJSao5Q6bf8HpVSeUmqF/fJ/bZ3TrXj4YOoxjUmWTSzami0jTxrSGr69H5uGmwuv4JqBsXSO9Dc6lWgtDdqaxIX50TnSn+V75DwIOEgBAZ4Elmitnwa+AJ4/w+Pe1lqPtF+ubrt4bqrP1XjrKvpXrGRDzgmj0ziOH7+CvQuZ638tJZ6R/G5sktGJRGsye9SfTM9eClozIimCDblF0rEaxykgk4F19p/X2K+fznCl1ANKqb8ppYa0TTQ3FpOGLTSRKyyr+VwOY9WrKoUFf6A4qCsPHxnG78clEeona324vMRRUJwPRTmkdwmnps7Gxtwio1MZrs0KiFJqoVJq22kuFwKRQJn9oaVAiFLqdIPp/6i1fhZ4CnhTKdX5DK91s1IqQymVUVgorQeaTSlMKTNIU7vZvmuHfOMCWPY4uuwId5ZeR0psGNcOjjc6kWgLDdqaDOwUhqfZxKq98n9LmxUQrfV4rXWf01y+Ao4BAfaHBgIntda/6lqmtd5o/7cC2AYMPcNrzdZap2qtUyMiIlpng9xFynQ0igl1K+TEYcFm2DibVcEXs6GmE89c2luG7bqLn9qa5KzAx9NManwIq/ZKqx9HOYQ1Hxhs/3mo/TpKKZNSKtb+8xil1IQGv9MZkJ7jrS04BuLTucJjFV9sOWh0GuNY6+Dru6n2juCOI5O5Y1RnukQFnPv3hGto0NYEay3pXSLIPFLGsdIqo5MZylEKyJ+AcUqpvwCXAPfZb++NvZhQv5cySyn1J6XUv4FPtdar2z6q+1F9r6YjRzm1dzVF5TVGxzHGxtfgyA7+Wncd7SIjuW2k9LtyOw3amqTb5/ysdvOGow5RQLTWRVrrWVrrx7XWM7XWR+23b9Na97L/vFNrfanW+kmt9Z1a66eMTe1Guk3F6uHHxep75u9ww9YmJQdh2RNkBQ7mg/K+PH1pb7wsZqNTibbWoK1J9+hAwvw83f4wlkMUEOHgPP0w97iYqZYNzN/ihkcNF/wBm83KjYXTuX5wJ/rHhRidSBjBJwQ69IfsZZhMimFdwlm19zg2N54jJQVENE6fq/CjkqiCJew/UW50mraTOR8yv2GO5UpsgTHcJwtFubeEUfWDKSpPMqxzOMdPVZN5pOzcv+eipICIxokdQl1gDJdZVvHFVjc5jFVdBt/ez3HfzjxTMoYnpvXCX1q1u7cGbU3Su9SP8Fy9z32H80oBEY1jMmHpexVDTbtYvWW7e6yXvvwpKC3gttJrmdwnllFdI41OJIzWoK1JuyBvkqL83fo8iBQQ0Xgp0zGhGVCyiO0HS4xO07oOb0dveIWF3hPZ59mdh6d0NzqRcARmD0gYAVmLQGvSu7h3WxMpIKLxQhOo6zio/jCWK88JsVnh699R5RHC/cXTeGhKd8L8vYxOJRxF8iQoOwSHt7t9WxMpIKJJLP2uJkEdInf799RabUbHaR2b5sChrTxcdTUpXeKZ1reD0YmEI0kaXz+cd88Ct29rIgVENE33i7GavRlXs5TVrnjst+Qgeulf2eXdn2/0EJ6c1gulpF2JaMAvHGIGwp75
bt/WRAqIaBrvQOg2lQst6/l6c47RaVqWtQ4++Q1Wm5XbS67hvvFdiQn1NTqVcETJE+HITig+4NZtTaSAiCYz972aQMqxZi7gVPWvel46rxVPwYH1PGKdRUiHJG4YEm90IuGokifV/5v1nVu3NZECIpqu03BqfKO5iBUs3HXE6DQtI3s5etXfWeE7no9rB/PMZdJpV5xFeBcI6wKZ8926rYkUENF0JjMe/WYw3LyD5Zt3GJ3m/J06Bp/dTJFPPLcVXcnjF/Wka7tAo1MJR5c8EfJWY6opZWjn+rYmbjE/qgEpIKJZVJ+rsGCjff7XHHXmY782G3x2M9aqEmYU38rU1M5cMSDG6FTCGSRPAlst7FtKehf3bGsiBUQ0T3gXqqL6cYlpFV9vc+Llbte8ADnLedx6A5Z2PfjrRT2NTiScRUwa+IbBnm//29bE3YbzSgERzeadeg1dTQfYsHa5c87EzV+PXvYEKz3T+YTRvHpNf7w9pE27aCSTGZImwN5FtPM3u2VbEykgovl6XoLN5MmQU4t4ekGm0WmapqIIPvkNRR5R3F56Pf+4oi+xYTJkVzRR8kSoKoH8dQzr7H5tTaSAiObzCcHUcxrXeixnw7rvnWf3XWv48g5sZUe4oex2rhvZi3Hdo4xOJZxR4mgwe8GeBaQnuV9bEykg4vxc8ARm3xBe9XmZhz7aSElFrdGJzm3Da7DnW56uuwr/TgO4d1yS0YmEs/L0g4SRkDmfgfEheJpNbjUfRAqIOD/+EahLZxNrK+D2qtd56MtdRic6u0Nb0YsfYo15AF96X8hLM/piMcvHQJyH5IlQvB/f4r2kxoewMstJ9sRbgHxyxPlLGIkadg9XmJdj2/kpXzrqqKyqUvTHMzlJEL+rnMXLV/cnIkC67IrzlDyx/l/7aCx3amsiBUS0jFF/QncYwDNec3j1i2UcLqk0OtHPaQ3f3IM+uZ9bKm7jtklppMaHGp1KuIKAdvVrpe/51u3amkgBES3D7IG6bA4+Hhae1i/y4MdbsNkcaFbu5rdg1yf8o+4yInuO4sah8UYnEq4keSIUbKa7fwWhbtTWRAqIaDkhcZgueokUtY+Bea/w7ro8oxPV2zQH/c29rFV9+C54Os9c1ltatIuWlTwZANO+hQxzo7YmUkBEy+oxDd3vBm6zfM33333EvmMGt3ZY9Q+Yfy8bPVK5ve73/OfaNPy9LMZmEq4nshsEx0Hmt27V1kQKiGhxasJT1IUl85z5ZR6bu8KYlQu1hsUPw9LHWGhK55aae3jp2sEkRQW0fRbh+pSq742Vs4Lh8fUTUp1mXtR5kAIiWp6nL5Yr3ibEVMVNx5/lX0v2tO3r26zwzd2w5p/M0+N4zPI75t6azvCkiLbNIdxL10lgrSaqcB1dIt2jrYkUENE6orpjnvgUI8w7qFr1ElvyT7bN69bVwKc3wea3+U/dhbwXehef35lOt2hpzy5aWexg8A6qn5XexT3amkgBEa0n9UZqk6fygOVDXvvgY8pbe/XCmgr0vKtg92c8WTuDzZ3v4qNbhxAV6N26rysEgNkDulxQv0ph5xBq6mxsynPttiZSQETrUQqPi/+F1S+KP1U+x/NfZ7Tea1WVYH13GnrfEv5YexPVaXcy+7pU/OSEuWhLyROh4jiDPXPwNJtc/jCWFBDRunxC8LryLWLUCVK2P8aDn+7gSEkLz9I9VUjtm5OxHczgrtrfkjzpTh67qKcsSSvaXuexYPLAO2ehW7Q1kQIiWl/sIGwj/sDF5rVM2z6Lx55/jie+2UVRec35P3fxAWpevwDrsT3cbr2fi66+k5lDO53/8wrRHN5BED8MMr9lWJdwMo+UOfeKnecgBUS0CcuI+2D8U/QLKuMV8/PM2HgZ/372Qf713Q7KqprRwfdENqx+kerZY6kuOcyd5oe565bbpC27MF7yJDixl6ntT2FS8PrKHKMTtRrl6rMlU1NTdUZGKx57F01jrYUfvqRq5T/xLtxBkfbnEzUen2G3
cvmIs6wIqDUc+4GK7Z9Tt/srAkvqhwZvtyXwWuBd/Pmm6XQI9mnDDRHiDIoPwIs9YdzfeODwCL7Yeohl942gY4hzLVimlNqstU4962OkgAhDaA3711C67AUC85dQrT34zjwc05A7mTBqJB5mE2hN4Z61FG/+jJD93xFecxCbVmzSySwljaPtx9K5S3euHxpPoLeH0VskxP+8Ogw8/Tl0yeeMen4Fk3tH848r+hidqkkaU0BkiIowhlIQP4zAG4fB8b0UL/o7E7I+xWv1Utav7UdtUBxJxSuJ0icI1mY20pP54ZdB18n0Su7Cfe2D8LTIEVjhoJInwcrnaO9Rzg1D45m9ModZ6QkuNx9J9kCEw9CnCsn77iVCdr+Dt64k028ApfETiUi9kKS4WBlVJZzHoW0wewRc/AolSZeT/uwy+seF8NbMNKOTNZrsgQinovwj6HTZ32Daw2Cz0sdDJgAKJxWdAgHtIXM+QX2u4o5RnXlqQSbrsk8wODHM6HQtRo4BCMdj9gApHsKZKVU/qTB7GdRWcf2QeKKDvHn6u0yXavMuBUQIIVpD94ugtgLWv4y3h5l7xiWx/UAxC3YdMTpZi3GIAqKUMimlblFKHVNK9TzL48Yqpf6jlHpUKfVIW2YUQogm6TQcekyD5U9CwRYu7deRpCh/nlu4x5glDlqBQxQQIAXYAFSc6QFKKV/gVeAerfWjQG+l1Ji2iSeEEE2kFEx5Afyj4LNZmOsq+MOEruQeL+fDTQeMTtciHKKAaK23aq23neNhg4H9Wutq+/U1wOTWTSaEEOfBJwSmvVrfOWHhnxndNZK0+FBeXLK39btT52+AwqxWfYk2KyBKqYVKqW2nuVzYyKeIBBquEVlqv+10r3WzUipDKZVRWOjazcyEEA6u03AY8lvY/BZqzwL+MLErx09V8+bq3NZ7zSM74f8uh6/vqp+020rarIBorcdrrfuc5vJVI5/iGNBwPdJA+22ne63ZWutUrXVqRISsQieEMNjov0C7XvDVnfQPrWF8jyheW5nDiVPV5/7dpirKgfcuAS9/uOT1+kNprcQhDmGdjVLqp9aq64A4pZSX/fpQYL4xqYQQogksXnDpHKgphy9v5/4LkqmoqePfy/e17OuUHoZ3LwZbHVz7OQTHtOzz/4JDFBClVIhS6i9AEHCzUmqQ/fYIYLVSyltrXQHcBryklHoc2KG1XmpcaiGEaIKIZLjgcdi3hM55c7lyQAzvr99P/okzjh1qmsqT8P4lUHECrvkEIpJbfc6JtDIRQoi2onX9uYm8VRRetYj0tw4xvkc7/jm97/k9b005vDcNDm2Fqz+mKiadx77eTYivJw9M6Nqsp2xMKxOH2AMRQgi3oBRc9DJ4+hGx8A5mDe7Al9sOsaugpPnPWVcDH10HBzfBpXPICUhl2n/WMnfjAZSiVfdCpIAIIURbCoiqLyJHd3InHxLs68Ez32U277lsNvjiVti3BKa8yNe1qUz912oOl1Ty1g0DuH98V5Q7n0QXQgiXkzwR+s/Ea+PLPJFyklV7j7NqbxOnHGgNCx6AXZ9SO+oR/nKgH7+du5XkdgF8e1c6o7qedpZDi5ICIoQQRhj/BIQlMmnfo3QLtnLj25u4a+5WMvKKGnfYacVTsOl1SvrdxrQdA3h/fT43D0/gw1sG076NVueUAiKEEEbw9INL30CVH+PTmI+5ZmAsy/cc47JX1zH5pdXM3ZhPRc0ZZquvfxW+f4aD8ZcybPNIDhRV8sZ1qfxpUrf61TzbiIzCEkIII636Oyz9K6T+hprgBLYctbI4u4I9xQrt6c/QHglMGZBMbHRUfdHZ+TF8Nosfg4cz5cgsesaE8e8ZfYkJbdk112VBKSGEcHRD74YDmyBjDp7AIPsFT/v9u+0XQCsTStvY4dGby4/cyHVDE3lwYjfDlneWAiKEEEYymeGqeWCtheqyX11Kik+QkZXP7tyDUF1KnbbwEVP45zWDmNAz2tDoUkCEEMIRmD3AN7T+
0kAQMCYNhlttLP7hKBl5J/lwSBxxYX7G5GxACogQQjgBD7OJSb2imdTL2L2OhmQUlhBCiGaRAiKEEKJZpIAIIYRoFikgQgghmkUKiBBCiGaRAiKEEKJZpIAIIYRoFikgQgghmsXlmykqpQqB/efxFOHA8RaK4whcbXvA9bZJtsfxudo2nW574rTWEWf7JZcvIOdLKZVxro6UzsTVtgdcb5tkexyfq21Tc7dHDmEJIYRoFikgQgghmkUKyLnNNjpAC3O17QHX2ybZHsfnatvUrO2RcyBCCCGaRfZAhBBCNIsUECGEEM0iC0qdgVJqLHAJcAzQWuvHDI503pRS64Eq+1Wr1nqMkXmaSinVDngcSNFaD7Df5g08DxQAXYCntdZZxqVsmjNs0w3ArfzvvZqjtX7PmIRNo5RKpH57tgAdgRNa678qpUKBp4Ec6t+nP2mtjxqXtHHOsj2PAiMbPPQJrfXitk/YdEopE/A1sIH6ldcTgRsBH5r4HkkBOQ2llC/wKtBDa12tlPpUKTVGa73U6Gzn6Tut9aNGhzgPw4AvgT4NbrsbyNdaP6uU6gXMAdKNCNdMp9smgOla67y2j3PeQoF5WusvAZRSPyil5gOzgCVa64+UUlOpL/rXGpizsc60PWitRxoZ7Dyt01o/DqCU+pL6L8vpNPE9kkNYpzcY2K+1rrZfXwNMNjBPS+mllPqDUupRpZTTbY/W+hOg7Bc3TwbW2e/fCaQopQLbOltznWGbAO5USt2nlHrY/u3dKWitN/30n62dCSinwfuEE32ezrI9KKX+bH+P/mD/0ukUtNa2BsXDQv2e1R6a8R7JHsjpRfLzD3Wp/TZn94zWeqNSygysVEqVaa1XGh3qPJ3pvSo1Jk6L+B6Yr7UuVEpNAj4GnOpwI4BSahqwUGudqZRq+D6VAiFKKYvWus64hE3zi+35GMjTWpcrpW4H/gX8xtiETaOUGg/cA3yjtc5oznskeyCndwwIaHA90H6bU9Nab7T/awVWAaOMTdQiXO690lrnaq0L7VeXASPsRd9pKKVGUf/3dY/9pobvUyBw0smKx8+2R2u9W2tdbr97GTDaqGzNpbVeqLWeAHSyF8Emv0dSQE5vHRCnlPKyXx8KzDcwz3lTSnVVSjX8htQF2GdUnhY0n/pDjtjPgWzXWjvz3gdKqafshxag/n3KtRd9p2A/PDoe+B3QTik1mAbvE072eTrd9iilnmvwEKf6LCmluv/iEHYukEAz3iOZSHgGSqlxwGVAIVDr7KOwlFLtgZepH00SCHgA92qtbYYGawKl1AjgOmAC8Arwd/tdzwOHgc7Ak042Cut023Qz0JP6D3Yv4J9a6/WGhWwCpVR/6g/BZdhv8qP+7+4r4BnqO2MnAn90klFYZ9qeZMCX+m/tvYCHneXvzj6y7Dnq/y/wALoBdwE1NPE9kgIihBCiWeQQlhBCiGaRAiKEEKJZpIAIIYRoFikgQgghmkUKiBBCiGaRmehCtBCl1CrqG9SFUd9b6HX7XR2oH/E43ahsQrQGGcYrRAtRSs3UWr+llOpJfXuI+J9uB97W8mETLkYOYQnRQrTWb53hrgDqJwWilJqplDqilLpfKfWeUmqBUuoKpdQcpdTKnxpBKqV6KKXetT9ujlIqoa22Q4jGkgIiRCvTWr/U4Oe3gExgi9b6WqAaCNBa/wbYCoyzP/QN4FWt9XPAe/xv1r0QDkPOgQhhjGz7v8UNfj7J/5rZ9QYuUEoNp36hn1NtG0+Ic5MCIoRj2g58prXeYW/qOc3oQEL8khQQIVqQUsqH+maIQUqpG7XWb9pbZQcppWYAx4E44Aal1FfU72lcq5Q6BAynftGvBdSvLfF7pVQuEAO8b8T2CHE2MgpLCCFEs8hJdCGEEM0iBUQIIUSzSAERQgjRLFJAhBBCNIsUECGEEM0iBUQIIUSzSAERQgjRLP8P/a6EBftwaG4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "tall = np.array(tall)\n",
    "plt.plot(new_dataset['x'][:30, 0])#[:100, 0])\n",
    "plt.plot(pred_tall[:30, 0])\n",
    "\n",
    "plt.ylabel(r'$\\theta_1$')\n",
    "plt.xlabel('Time')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ba2d2e69723d4378bfe552e0ed8fcd35",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=0.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 current error 0.0038181357\n",
      "1 current error 0.021209426\n",
      "2 current error 0.01846676\n",
      "3 current error 0.047700986\n",
      "4 current error 0.050402828\n",
      "5 current error 0.046946183\n",
      "6 current error 0.044238884\n",
      "7 current error 0.039473616\n",
      "8 current error 0.035828345\n",
      "9 current error 0.032732606\n",
      "10 current error 0.030316547\n",
      "11 current error 0.028577417\n",
      "12 current error 0.028252253\n",
      "13 current error 0.029273892\n",
      "14 current error 0.031878117\n",
      "15 current error 0.029994529\n",
      "16 current error 0.03498629\n",
      "17 current error 0.03309936\n",
      "18 current error 0.03261298\n",
      "19 current error 0.03116028\n",
      "20 current error 0.030652355\n",
      "21 current error 0.029808465\n",
      "22 current error 0.03206598\n",
      "23 current error 0.030886434\n",
      "24 current error 0.02967531\n",
      "25 current error 0.028906314\n",
      "26 current error 0.02900661\n",
      "27 current error 0.028143957\n",
      "28 current error 0.027396614\n",
      "29 current error 0.026806256\n",
      "30 current error 0.026434598\n",
      "31 current error 0.025650725\n",
      "32 current error 0.024904592\n",
      "33 current error 0.02522585\n",
      "34 current error 0.089925386\n",
      "35 current error 0.08756612\n",
      "36 current error 0.08539982\n",
      "37 current error 0.08411583\n",
      "38 current error 0.08209654\n",
      "39 current error 0.08006577\n",
      "40 current error 0.07837987\n",
      "41 current error 0.07682229\n",
      "42 current error 0.07506938\n",
      "43 current error 0.07363075\n"
     ]
    }
   ],
   "source": [
    "all_errors = []\n",
    "for i in tqdm(range(100)):\n",
    "    max_t = 100\n",
    "    new_dataset = new_get_dataset(jax.random.PRNGKey(i),\n",
    "                                  t_span=[0, max_t],\n",
    "                                  fps=10, test_split=1.0,\n",
    "                                  unlimited_steps=False)\n",
    "    t = new_dataset['x'][0, :]\n",
    "    tall = [jax.device_get(t)]\n",
    "    p = best_params\n",
    "    pred_tall = jax.device_get(odeint(\n",
    "        partial(baseline_eom, learned_dynamics(p)),\n",
    "        t,\n",
    "        np.linspace(0, max_t, num=new_dataset['x'].shape[0]),\n",
    "        mxsteps=100))\n",
    "\n",
    "    total_true_energy = (\n",
    "        jax.vmap(kinetic_energy, 0, 0)(new_dataset['x'][:]) + \\\n",
    "        jax.vmap(potential_energy, 0, 0)(new_dataset['x'][:])\n",
    "    )\n",
    "    total_predicted_energy = (\n",
    "        jax.vmap(kinetic_energy, 0, 0)(pred_tall[:]) + \\\n",
    "        jax.vmap(potential_energy, 0, 0)(pred_tall[:])\n",
    "    )\n",
    "\n",
    "    scale=29.4\n",
    "\n",
    "    # translation = jnp.min(total_true_energy) + 1\n",
    "    # total_true_energy -= translation\n",
    "    # total_predicted_energy -= translation\n",
    "\n",
    "    cur_error = jnp.abs((total_predicted_energy-total_true_energy)[-1])/scale\n",
    "    all_errors.append(cur_error)\n",
    "    \n",
    "    print(i, 'current error', jnp.average(all_errors))"
   ]
  },
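  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Known issue:** the evaluation loop above freezes. The last system it completes is `i = 43`; the next one (`jax.random.PRNGKey(44)`) never finishes, so the running average only covers the first 44 systems.\n",
    "\n",
    "NOTE(review): the `odeint` call passes `mxsteps=100`, whereas `jax.experimental.ode.odeint` names its step limit `mxstep`; confirm the limit is actually being applied."
   ]
  },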
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make plots:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-0.06, 0.01)"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZAAAAEfCAYAAABvWZDBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3hUZfbA8e9J6L0XwdCLVIGIYkNEQUV3V90fuxawLq7dVWEBFRsKKvaKBXVxdW1rRUQpShGkLFVa6EUgdAiBhITz++PemUySmclNmUkmOZ/nmYd73/vemXNhyMm9bxNVxRhjjMmvuOIOwBhjTGyyBGKMMaZALIEYY4wpEEsgxhhjCsQSiDHGmAKxBGKMMaZALIEYY4wpkDwTiIj8IRqBGGOMiS2S10BCEZkP/AT8S1VXRCMoY4wxJZ+XBHIKsBm4HugIfK+q30Q+NGOMMSVZOQ914oFMIA04E2gmIv2AWar6SSSDM8YYU3J5aUT/AFgJdAEGquqlqnon0CmikRljjCnRvNyBrAFuVtXDvgIRqQDUjlhUxhhjSjwvbSDXquoHUYrHGGNMjPCSQPYD+wOKFFgP3Gu9sowxpuzy0gbyBHAh0AroBzwN3ATcHsG4jDHGlHBeEkh1VV2vjnVAU1XdinMXYowxpozy0oh+qoj8H5AEtAW6iUgdoFtEIzPGGFOieWkDaY7z2Koj8BswHOfOpZmqTotwfMYYY0ooLwnkOeB9VV0anZCMMcbEAi9tIG2AZZEOxBhjTGzxkkDmAdV9OyJyT+TCMcYYEyu8PMLaDNQHdrlFNVS1bqQDM8YYU7J5uQP5SFWrqGoLVW0BDIt0UMYYY0q+PO9AAEQkDqgD7FUvJxhjjCn1vKxI2A/YAEwArhaRWyIelTHGmBLPyyOsy4D2wBxV/TfOlCbGGGPKOC8JZJuqHsOZRBHgQATjMcYYEyO8TGXSVkSGA+1F5A6gaYRjMsYYEwO8dOOtDozAWZFwCTBWVVMK/cEiFwBXAMmAquqjOY5XAsYB23EGM45V1bXusdbusQxV/XNhYzHGGJN/nnphZTtBpLuq/q9QHypSBWd0e0dVTRORz4HXAufWcu96Tqjq0yLS2T1+jnvsGqAq0M8SiDHGFA8vvbC6i8iLIjJBRCYAbxbB5/YCNqtqmrs/BxiQo84AYC6Aqi4HuopIDXf/30B6EcRhjDGmgLy0gbwCPAfsdvcHFcHnNgAOB+wfcsu81Dnk9UNEZAgwBKBq1ao92rdvX6BgjTGmrFq0aNEeVa0f7JiXBDJPVT/z7YjI6iKIKZmA+bWAGm5ZfuuEpapv4t4xJSYm6sKFC/MfqTHGlGHudFZBeenGW0FEnhCR60RkMM4dSWHNBZqJSEV3/yxgkojU8T2mAibhPOrCbQNZqqqe7z6MMcZElpcEchZOe0NzoAXOlCaFoqqpwK3ASyIyGljmNqAPB25zq72Ik2QeBO7DWYcdABH5I+4ARxGxubmMMaYYeOnGe4aqzgvYb6OqSRGPrIjZIyxjjMk/EVmkqonBjoW8AxGRu0SkV47k0Qe4KgIxGmOMiTHhHmG1BNaIyMMiMspdG30eUDHMOcYYY8qIcAlkv6ruA8YDTXDmxDoKpEYlMmOMMSVauASiAKq6EziiqhmB5cYYY8q2cONA+otINXf7HBF52t0+A3gysmEZY4wp6cIlkHTgiLv9bUD58ciFY4wxJlaESyDDVHVBzkIR6RHBeIwxxsSIkG0gwZKHW74ocuEYY4yJFV5GohtjjDG5WAIxxhhTIPlOICLSPxKBGGOMiS0hG9FFZAa5x3wIkAC0imRQxhhjSr5wvbDmA6/lKBOcWXSNMcaUcSETiKr+M1i5iHwcuXCMMcbEijxXJBSRpsBdQD2cO5DOQNCpfY0xxpQdXhrRx+KsIHgc+BBYHNGIjDHGxAQvCWSJqn4BbFDVH4EtEY7JGGNMDPCSQHqI
SDOgvohcC/SJcEzGGGNigJcE8gJQHXgduAJ4OaIRGWOMiQl5NqKr6q8Bu1eIyBkRjMcYY0yMCDeQ8HZVfVVEJuQ41AXrhWWMMWVeuDsQ39K1ArwXUD4oYtEYY4yJGeEGEr7rbo5S1a0AInIy8G6oc0ojVUUV4uKE3w8cpXqlclSvVL64wzLGmGLnpRH9hoDtFODGovhgEblARF4TkUdE5OEgxyuJyCsiMkJEJohI24Bj14rIsyLytIjcUhTxhPLsD2tpOfI71iWncObY6XR+5IdIfpwxxsSMcG0gvYHzgN4i4iuOA5oU9kNFpArwBtBRVdNE5HMR6auq0wKq3QNsUdWnRaQz8A7O2uxNgfuBbqqqIrJARKaralJh4wrmX3M3AfDitIi8vTHGxKxwdyAHgE3AQWCz+1qHM61JYfUCNqtqmrs/BxiQo84AnBHwqOpyoKuI1AD6A4tU1TdT8Fzg4iKIKahDxzIA+Gbp75H6CGOMiUnh2kCWAktF5DtV3e0rF5G6RfC5DYDDAfuH3DIvdbycC4CIDAGGACQkJBQu4gBH0zOpXCG+yN7PGGNikZc2kKMicrmIDBaRwTiPngorGWdwok8Nt8xLHS/nAqCqb6pqoqom1q9fv0CB/u+hC3OVJR8+VqD3MsaY0sRLApkE9AVauK86RfC5c4FmIlLR3T8LmCQiddzHVL7P7QXgtoEsVdVDwBSc6VV8DTO9gMlFEFNQdapWoH2j6tnK7vrPkkh9nDHGxIw8R6IDq1T1Dt+OiLQo7IeqaqqI3Aq8JCK7gWWqOk1Engb24cwA/CIwTkQeBFoDN7nnbhORccDzIpIJvB2pBnSf1TsPZ9tfuvUAW/elcnKdKpH8WGOMKdEkqy06RAWRocASYC3OErc3quojkQ+taCUmJurChQsLdO4Tk1by1qyNuco3jc3Z7m+MMaWLiCxS1aCzj3i5A7kXWB2wnwA8UgRxxYwHBnTg3Lb16dC4Bj1GT/WXT1u1i76nNCzGyIwxpvh4aQMZoap9fC8gogP3Sqpz2tSnbrWK2cpuen8hy7cdJD3jRDFFZYwxxSfPBKKq74lIFxHp6w7im5bXOWXJZa/M5uynphd3GMYYE3V5JhC3DeRFYDDQBqeBu8z66G+5Z7NPPpwWpKYxxpRuXh5hVXMfXa1U1RlAmf5p2atVXabd1ztX+YrtB4shGmOMKT5eEohvyLXm2C+zqlXM3ffg0pdnF0MkxhhTfLwkkEwR+R64VEQ+A45EOKYSr0H1ipzZKveMLs2HT2LSsh3FEJExxkSfl0b0h4HngK+B8ar6ZMSjKuFEhA+DtIUA3P7h/6IcjTHGFA8vjegDVfUHVR0H7BaR56IQV0y4vU8rbjm3Za7ypF2HaT58Es2HTyqGqIwxJjq8DCRs79tQ1SUicm0E44kpQ/s7fzWLtx5g/sZ9/vIBL2W1h5w4ocTFSa5zjTEm1oW8AxGRu0VkI/APEdkgIhtFJAlvSadMOXY8M9t+embWwMKDR49HOxxjjImKkAlEVV9U1RbAEFVtqaotVLWNqt4Txfhiwtmt64U8lpojuRhjTGnhpRH9U3f98vtE5PxoBBVr+p7irGd1WdeTch2bk7Qn2uEYY0xUeGlEfwhnQsVmwFB33wTo0awOP/zjXF7666k8clmHbMeGfb6smKIyxpjI8tKeUUFVL/HtiMiYCMYTs9o2dBadqlAu9zjLQ8eOU6NS+WiHZIwxEeVpIGGOfZt6NoyalXMnii6P/FAMkRhjTGR5SSAZIvK1iLwgIt8ARyMdVCwb0KUx57atz7s3nJatfO2uwyHOMMaY2JTnIyxVHS0i/YAuwCRV/THyYcW2f93YM1dZv+dn8q8be1KnagU6NalZDFEZY0zRCplARKQecD+wB3hBVe05TD7FCZwIWDF48IT5gC2Fa4wpHcI9wnoDOIazBsjw6IRTunz69zODln+5eHuUIzHGmKIXLoEkqeojqnoLUD1aAZUmPZrVZv7IvrnK7/l4
STFEY4wxRStcAglsLPdP4S4iNhI9HxrUqFTcIRhjTESESyC3iMh8EZkfsL0AuC9KsZUal3ZpnKvMZus1xsS6cL2wfgDeC1I+qDAfKCJ1cNZV34DTvjJSVXcFqXct0A1nHMp6VR3vllcE7gIeA+qrakph4omGV67uzrfLLFkYY0qXcAlkmKruzlkoIr8V8jOfBKaq6icichkwjhxJSUSa4vQA66aqKiILRGS6qiYBZwCfA08XMo6oenNQD1ZsP8jirQeYFTA/VkpaBt8t20HF8nH88dQmxRihMcbkT8gEEix5uOWFnR1wAPCEuz0HeD9Inf7AIlX1dYKdC1yM07D/MzirAsaSfh0b0a9jIx7+akW2BNLp4Sn+7T90PSnmrssYU3Z5GYmebyIyRUSWBHn9AWgA+IZlHwJqi0jORBZYx1evQQHiGCIiC0Vk4e7dQfNh1P0tyAqGPr/9fiiKkRhjTOFEZHEoVe0f6piIJON0Cz4A1AD2q2pGjmrJQOuA/RrAugLE8SbwJkBiYqLmUT0qqlUM/Ve+PzU9ipEYY0zheJnOfUiO/WcL+ZmTgF7u9lnuPiISJyIJbvkUoIdkPc/pBUwu5OeWCLWqVGBwr2Z8fcdZuY7NtrVDjDExxMsjrCEi8lcRKSciLwE3FfIzRwIXisiDwBU4jeXgzrUFoKrbcBrXn3cT1ttuAzoi0tw9F2CYiLQnxjz2x050aVqLiTdlnzNr/MwNTFu1i50HjxVTZMYY451ktVOHqCDSGLgSGAa8DXylqkujEFuRSkxM1IULFxZ3GNms3XWYfs/PDHps6ah+1Kxia4gYY4qXiCxS1cRgx7zcgbwL3Ab8E6gDnFd0oZVtbRpUC3nsjDHT+Mv4uVGMxhhj8sdLAqkO9FbVj3DuQs6LaERliIgw4uL2DO3fLtexo8cz+XXjvmKIyhhjvPGSQG7zjQlR1XTAlrQtQrf0bsXtfVpzeos6QY+/MHVtlCMyxhhvvCSQ5SJysYgMFpHBwNBIB1UWjfu/rkHLX5iaFOVIjDHGGy8JZDxwEXAt0A6nHcQUsSa1KnPz2S2YfPc5uY4N/XQpd320uBiiMsaY0LwkkE2qejcwTVUfwBmjYYpYXJzw4KUdOKVxjVzHPl20ja+X/l4MURljTGheEkhD98967iSHuUfAmajIyDxR3CEYY4yflwSyUkQG4IwEXwasiGxIZuq9vbmuVzPeGpy96/WYyatZsf0gz/6wppgiM8aYLCEHEopInKqWml95S+JAwrzsO5JO98d/DHps1rA+nFynSpQjMsaUNQUdSPiYiFQWkSo5XqMiFKfJoU7VCiGPfbF4exQjMcaY3MIlkJFACs606ikBr4ejEJfJQ7jkYowx0RAugQwAXgMGqGqc70XW5IcmCuaOOJ/fHu3PwMSm2cof/NKaoowxxStkAlHVyap6JxAvIi+LyIVu+fNRi87QuGZlqlYsx+g/dc51bM66PWzcc6QYojLGGA+9sFR1kptIeolI8KljTcRVKJf7n+qat3+lz7ifWLjJ5swyxkSflwWlOorIZ8ANwMTIh2RCGdC5Mac0rpGr/WPg+Lk0Hz6JvKbmN8aYohQygYhIexH5D/At8CPQVlXfEpHQi3qbiHr1mu5Mvvsc9h3JvvTtCTdvpKZnFkNUxpiyKtwdyAqgOc7su8eAq0TkOuCpKMRlwphwfdAu2fR8Yioz1+62ebOMMVFRLsyxicB7QcoPRiYU49W5beoHLT+SnsngCfMBePiyDtStVjGaYRljyphwCeQhd23ybETktwjGYzwoFx/Hee3qc1XPBG6ZuChond0paZZAjDERFe4R1isi8rWIXC8i/iliVXVPFOIyeXjvhp7079go5PFRX/7GzoPHmJW0O4pRGWPKknDjQP4EXAeUB/4jIp+JyEARqRy16EyeLuzgTJY84/7zspXP37SPM8ZMY9A781m8ZX8xRGaMKe3CPcJCVfcDbwFviUgjYCDwtYi8r6ofRCNAE96bg3pw
7PgJysVLyDqfLtpGt4TaUYzKGFMWeBkH8iSAqu5U1ZdU9ULgPxGPzHgiIlSuEE/5+ND/lLsOHotiRMaYsiLsHYirs4i8CvwGvK+qR1Q1o6AfKCJ1gLHABqANMFJVdwWpdy3QDcgE1qvqeBER4F/AWpzk1wq4VVVtPg+ga9OarEtOoV/HRtlm6522OrkYozLGlFZeEshAVT0qIh2AV0VkL/CKqm4s4Gc+CUxV1U9E5DJgHDAosIK78uH9QDdVVRFZICLTcZLOBlV93K33OvB34NkCxlKqfHHbWYj7JMumezfGRJqXFQn7ikgPYATQC9gBXCkiIwv4mQOAue72HHc/p/7AIs2am2MucLGqZqpq4HTycThTzBucddVFnFdOzYdPovnwSew+nFYMkRljSiMvdyAfAIuBl4HrfKsUisiLoU4QkSlkraUeaBTQAGeNEYBDQG0RKZfjsVhgHV+9Bjk+oznQErjLwzWUWee0qcespKye12eMmUank2rw+a1nUi5Mu4kxxuTFSwJ5SFVfDiwQkfLAklAnqGr/UMdEJBmoDhwAagD7g7SpJAOtA/ZrAOsC3qMpzhQrf1HVkL9Si8gQYAhAQkJCqGql0pe3n8XCTfsYPWlVtvLME8rSbQd5e/ZGBvdqxrJtBzmjZd1iitIYE8tCronuryAyDngR8D0XUWBHQRvSReQNYHpAG8hAVR0kInFAU1Xd4iaIbwloAwGuVtUkEWkFPALcrqqHRORKVf08r8+NxTXRi8KMNcnc8O6CsHVm/7MPTWvb+urGmNzCrYnu5Q7kIuDPwGacyRX3AOVFZLSqflKAeEYCT4lIW5xeVL4VDrvgzL/VWVW3uYnreRHJBN52k0clYCawHWc8CkASkGcCKav6tGuQZ53ZSXv4a8+ydYdmjCk8LwnkY2C0eycQB9ynqs+IyLNAvhOIqu4D/hakfAnQOWD/A5z2l8A6x4Am+f3Msu7jIWdQu2oF+j0ffD2w2lUrsP3AUWat3W2JxBjjmZcEUtPXG0pVT4jISW75jsiFZYrS6Xm0cdwycRHVKpYjJS2DTk1q0qlJzShFZoyJZV664TQUkVdF5B4ReQ2oLyKdgPMjHJspYue0qQfA+EE9ch1LSXOatB77ZmVUYzLGxC4vdyA3ATcDHXFGo7/tnjc4gnGZCJh40+l51lm67QAAe206eGNMHrwkkJnAbar6WkBZGmDTh5RCaRknaD58EgBPXN6Ja05vVswRGWNKKi+PsJap6v98OyJigwZKgS9vP4vXr+nO+zf2DFnngS9WRDEiY0ys8ZJAtorIRSLSTEQSgH9GOigTeaeeXIuLOzemd9vgy+MCtKxflbdmbmD/kfQoRmaMiRVeBhLuAFYHFCWoaquIRhUBZXUgoRe+R1Yn1azE7yGmfn/j2u7M37ifhy49JehcW8aY0qmwAwlHqOp7AW92QVEFZkqGyXefwzdLfwfgtZ/WB63z9w+cp5iJzWvTpFZlXp2xjuf/cipVK3r5ChljSqM8//er6nsi0gWoD6wBpkU8KhNVpzSuwSmNa6CqIROIz9pdh/li8XZ+XLmLFdsP5jnGxBhTenlZkXAozlxYg3EWgBob6aBM8RARyodZGhdgzc7DbNmbCsC3y3b4p4nfcfBoNEI0xpQgXhrRq6lqH2Clqs7A6cJrSql5I/qydFQ/zm5dL+jxySt2smaXM9P+xHmb/eW9xkyPSnzGmJLDSwKJd//UHPumFKpbrSI1q5Sne0ItAO69sG2+3yM1PYPPF21DVdm6L5V/fraMY8czizpUY0wx89ICmiki3wNVRKQn8L+8TjCx795+7bj1vNZUrhDPcz+uzbP+gC6NAViXnMJ7v2zkg3lbOKlWZT5ZuJUvFm/n3Lb1/XWMMaWDl0b0h0WkH85060tV9cfIh2VKgsoVst9sDruoHU9/vyZoXVWlw6jvSU3PutN4a9YGpq9OBmDr/lSOHc9kT0qarT1iTCmR5ziQXCeIXK6qX0QonoixcSAF
l55xgqPHM6lZubx/zEjFcnGkZZwo0Pu9cnU3Lu1yUt4VjTHFLtw4EC+9sG4WkWUiskFENuJMpmjKkArl4qhZuXy2sm/uPLvA73fHh4sLG5IxpgTw0gZyNXC+qu4BEJHrIhuSKclWP34Rizbvp23D6v6ygYlN+WThtmKMyhhTHLz0wprnSx6u5ZEKxpR8lcrHc1aOLr5dT66Vr/do36h63pWMMSWelzuQLiIyB/B1xekMBH0eZsqW1Y9fBMCwz5blOnZ++wb+BvRc5+08HNG4jDHR4XUcyEjgPfe1JILxmBhSqXw8lcrH8+zArnRuUpOp9/b2H9sXMIPvhzfnXsjKN4L9iLsSojEm9ni5A7lKVQ/4dkTklwjGY2JQ+fg4f6P657f24srX5zKsfzua1q5CUvJhzgwxqh3guR/XMuLi9hw8epy61SqSmp5BlQrlUFV+P3iMJrUqR+syjDH5FLIbr4i8B3ymqt8GlA0ArlXVq6ITXtGxbrzRk5aRScVy2ceQjPluFeNnbuCBS07hie9W+csrl4/nqDtK/dy29Zm5djef39qLhZv2M2byaqbd15tW9atFNX5jTJaCduPdqKrfisjrIvKZiLRV1UmALVNnwsqZPABGXHIKm8YO4Ke12dtFjgZMcTJz7W4AJszZxMwkZ3vb/qNs3ZfKaz+tI79jlowxkeVlJPqtIvK8qvoa0e1/sSmw285rzZx1e8PWSc84Qcoxp23kugnz/eU1K5e3NdqNKUHCJRANsV0oIlIHZ0r4DTjTw49U1V1B6l0LdAMygfWqOt4tHwtUAXYAvYD7A5KbKeHOal2Pizo2IuPECaauCt5L68eVu4J29f104TZLIMaUIOESyC0icqm7nSAiZwMCNAKeLMRnPglMVdVPROQyYBwwKLCCiDQF7ge6qaqKyAIRma6qSUAqziqJKiL/AIYCfytEPCbK3hjUA8haSjeYYF191yenRCwmY0z+hUsgP+B0281pUJCy/BgAPOFuzwHeD1KnP7BIsx56zwUuBpJU9bGAeq2BlYWMxxSTSzo34rvlOz3X79y0ZgSjMcbkV7gEMkxVd+csFJHf8npTEZkCNAxyaBTQAPD9enkIqC0i5VQ1cEBAYB1fvQYB798OGIazzO79YeIYAgwBSEhIyCtsE2U3n9OS75bv5MnLOzPyi7wnOPhlffi2E2NMdIXshRUsebjle4KV56jTX1VPDfL6GkgGfA+4awD7cyQPctTx1fM/MFfVNap6E/Alwe9gfPXeVNVEVU2sX79+XmGbKOueUJtNYwdw9ekJtGngdNV9+apuYc959Jvf+P2ALZ9rTEngZSR6UZuE0/gNcJa7j4jEiYjvNmEK0ENEfAt09wImu/WGBrzXRqBlxCM2EffDP85l8t3ncFnXrGnen7i8U656787ZxJljp7MuOYXJy3fw2+8HWbvrMImjf+Tg0eNs2ZvKqK9WkJFZsKnmjTHe5Xs9kEJ/oNML6ylgM9AKGK6qu0TkVGCiqnZ2612LM+dWJrA2oBfWJ0AScAToDjyvqnPy+lwbSBg7knYdZtwPa3j9mh60HPkdABec0iBkry2fhy7twLRVu/hl/V4++3svEpvXiUa4xpRq4QYSepnKJOebDVHVNwsajKruI0ivKVVdgjNRo2//A+CDIPUGFvSzTWxo07A64wc539f+HRsyO2kPic3r5JlAUo5l+AcmJh9O8/fy+uaOs60B3pgICJlARGQfcCBnMU57RIETiDH54Uskfcb9lGfdr5Zup3aVCgDc9u//+cs/WbjVEogxERCuDeQOVW2Z49UCuDNawRnj89UdZ+VZZ8PuI+xJSctVPnHe5kiEZEyZF64X1ochDlkXGBN1NSqVZ9PYAXx/zzlh623em5qr7JrTrQu3MZHgZU30viIy39ZENyVB+0Y18n3Ov3/dEoFIjDFeuvFehTMyfDzO3FXPRDQiY/LgW6DqqSs751Ezi83ka0zR85JA1qjqfsA3Wrx2hGMyJqwzW9dj
09gB/OW0BN4a7DSyT733XP/xXi3r5jqnxYjvaD58EsczT5B5Qpm0bAfHM0+w3R2UeOjYceasy3OMrDEmgJduvL1FZBFQSUTexrkLMaZEuLBDQzaNHZCt7NE/dqTf8zOD1h/x3+V8tmhbtrIp95zLM1NWM3VVMgsfvIB61SpGLF5jShMvCeQvwAlgHnAz8HREIzKmgDaOuQRViIsTqlaI50h6Jg8OOIXRk7JWQMyZPAC27ktl1Q5n6rU9KWl8vmgbS7cd4LVrekQtdmNikZcEMkBVPwEQkZnA34F7IxqVMQUgIvgmv/ntsYsAuOm9BXme98CXy/HNfDJx7mZ/o/vvB45ykq3JbkxIXtpA2vs23NHimWHqGlOidG+Wd5PdrkNp/sQT2GNr+H/zniHYmLIs3Ej0u4F7gFoicj3OKPQM3MkPjYkFt53XioxMpXGtSgz7bFnIersP5x6AuO9I7jJjTJZwAwlfdEeeD/GNQlfVNqp6TxTjM6ZQRIS7L2jDwMST833uybWrRCAiY0qPPB9hqeqnInKBiNwnIudHIyhjIqlGJW9ziE5esZPVOw/lKj92PJO57uJWR9Iy+O33g0UanzGxwstI9IdwGs2bAUPdfWNizjd3nA3ABzefzuBezQD840gA2jasluuci16YRfPhkziSlsGqHYcYOH4ul748m6vemseG3Sn84+MlDHhpNilpOddEM6b08/KrWAVVvcS3IyJjIhiPMRHTuWlN/5iRDo1rcF67+vRp518pmWf+3JU/vuosLVOrSnkOpB73H+v48JRc77fz4DF+WuMs3Lk3JY1Obp3Vj19EpfLxEbsOY0oKL72wcva6sqXeTMwrFx/H+e0bIiJsHHMJG8dcQteTa/mPX39m8zzfY8m2A6S7/X/f+Hm9v3zyih1FHq8xJZGXBJIhIl+LyAsi8g02G68pZZzxI04/3oUPXsCvI/sy5bddeZ739Pdr/Nsfzd/q3/7X3ODTxx9Jy2DJ1pxL7BgTu7w0oo8GXgG2AS+5+8aUSvWqVaRhjUqs2pG78dyrvu0bBC2/bsJ8/vTqHA4dO843S3/3Py4zJlaFTCAi8rOIXAegqj+o6jhV/TF6oRlTfFY+1r/A527dl3WTrqocSctgw+4UFm7eDxn30ikAAByISURBVDh3Ind+tJilWw/4G9/TM0I/Gd53JJ2Fm/YVOB5jIiXcHch8VX0/Z6GIWOugKfWqVCjHz0PPY+GDF+T73I8XbmXR5n2oKh/O30LHh6dw/rM/+4/vP5LVOP/hr5tpPnwSbR+czO8Hgj8dHvTOr/z5jblknlD2pKSFrGdMtIVLINVF5GQRSQh8AdYLy5QJzepWpV61ivxfj6YA3N03ayLqx//Y0b/doHru2XuvfH0uLUZ8xwNfrMh1bOK8TQHbWe0l8zbs9W9/tWQ75z0zA1Xlt9+dx2mHjx0ncfRUzhw7nbQMm1HIFL9wCWQQsCjI69YoxGVMiTH68k7ceX5rhpzbks9v7UVCnSpc0b2p//gLfz01X+/30fytNKzhJJ3Ax12zk7LWIxnx3+Vs2pvK01OyGur3Hkn3b09dmZzv6zCmqIVLIK+pagNVrR/4wmbiNWVMxXLx3NevHVUrlqNHszrMHNaHqhXL8cJfTuWCUxpwZqt6/rpeuv+CM4FjTqe3rMP+I+m8PC2J1HTnDuP1n7K6B/cNeAw2d4MtfmWKX7iBhEHXAFXVtwrzgSJSBxgLbMBZnGqkqubqMyki1wLdcMahrFfV8TmOvw2cqqqJOc81Jhr+1K0Jf+rWBICkJy4GoNeY6QV+vzdnbuCfn3ubAbhyiIGKXy3ZzgtTk5h+X28+W7SND+dv4YvbzipwTMaEEy6B1BeR64I1pBfSk8BUVf1ERC4DxuE8LvMTkabA/UA3VVURWSAi01U1yT1+LXCkiOMypsDKxzs38/FeRlaFsH6396/0W7M28tasjQCsGX0R+48c5905Gxk/cwMAaRknGOrOPnzseKaNjDcREW42
3hsikDwABgBz3e057n5O/YFFquq7C5oLXAwgIqcAHYAvIhCbMYXy89A+ADz95y7+sgrlsv6bzRvR179dPl6K5DOf/n4NZ4yZ5k8e4HT99Uk+lMYv6/fwwbzgAxyNKahC/L4UmohMEZElQV5/ABoAh92qh4DaIpLzTiiwjq9eAxGpAvwTeMRjHENEZKGILNy9e3ehrskYLyqVj2fT2AHZpo9/e3Aia0dfzPT7etOoZiV/+U9usvFiaP92IY+9M3tjrrJfN2b16Hrq+9Vc/davPPjlCtbszPpvdfDoceZvdMaXvDlzPc2H21I/Jn8ikkBUtb+qnhrk9TWQDFR3q9YA9qtqzqlMA+v46iUD5wP7cRryrwYaichwEQk69FdV31TVRFVNrF+/fhFeoTF5W//kJXx+65mc27Y+FcrF0bK+M9vv/JF9WfpwP5rUqkx8nHBVz4Sg51/VMysJ/e2clv7tAV0a5/nZ//h4KfFxzh3OpOVZc3P9d/E20jNOMDtpD10f/YGB4+eSmp7Bk9+tBuDQseNB38+YYLwtjBBARK5X1fcK8ZmTgF7AVuAsdx8RiQOaquoWYApwp4iI+xirF/Cy2wbyrVv/PKC7qo4tRCzGREx8nNAjyJK6DWpk3YWsf9KZ6Pqj+Vty1fvraQmMurQj2w8cpUK5OB66tAPVK5ULu7JioMwTufvBtG9UnWvf/pX5ASPbA6ei33nwGAKs2nGYni3q+Ms37jnC10t+566+rf3zhhmTZwIRkUeAm4F0nGVtawDvFeIzRwJPiUhboBVOYzlAF2Ai0FlVt4nIOOB5EckE3vY1oLsxJeI0vDcWkeGWREysu/nsFrw9eyM3n92CEZecwrb9qTSrWxWA1g2cO5ebzm4BwENfriAtzNQn4bw6Yz3rklOylS3atN+//dg3K5m9zukiPOmus+l4Uk0A+oz7CYBrzkhg2/6j/PDbToZd1L5AMZjSQ7LaqUNUEPkW+IOqnnD3/6SqX0YjuKKUmJioCxcuLO4wjCm0HQeP0u+5mbx/U0+ueO2XInnPetUqsCclPVvZn3s0ZdAZzbjxvQX+QYw/Dz2P3s/8BMC/buzJuW3t0XBpJyKLQg2X8NIGssCXPFy2fqcxxahxzcosf7Q/3ROyHo/dc0HWNCvf3nl2vt8zZ/IAqFutAn98dU62EfCHj2U97sp5J+Pz/i+baD58EoetPaXU85JALhaRzSIyQ0RmAG9HOihjTP6c3boeSU9czLwRfenUpKa//IW/5G+alUDjf96Qq+zSl2fjawKpXbV8tmMn3DaX93/ZBMD2A0dJTc/gpWlJHM+0dehKIy8JZBNwLnA9cAPwYQTjMcbkw/onL+HDv51OYvM6lI+P83cT/uSWXvx0/3n+kfKNA7oPF5bvqfe7czZxxG2A/2rJdlqO/I6UtAx/gjmYepwOo6bw3I9r+WRh1oJbk5fvoPnwSWzea2OBY52XBaX+qqqb3dcm4N+RD8sY40V8nGSbi8unZ4s6NK/nNMKveuwiZg7LGnPSp11Wu8W4/+ta4M9etu0gHR+ewmUvz+aRr38D4Mwx0/yPuX5YmTVD0ddLfvdvP/ODM0Hkyt8PkZqewQNfLLfHXTEq3IJSl7p/jgp84axOaIyJEZUrxFM+Po6nruwMwKN/6MTPQ8/j771bcWX3Jv56E64v2LRyy7cfZH+qkwAOHcsg+bAzUWTgAMc4EVLchbQ2uFO2lIuPo8OoKfz71y3+cSgmtoS7AznN/bMbsDngZYs6GxODBiaezMrH+pNQtwrN6lZl+MXtERE+uaUX0+7rzfntGwLQv2PDoOcHNtTnV7l4odPDU/hmadadSOBdR85xMNZmEhvCzYX1sLt5l6q+73sB90QnNGNMURIRqlTIPfSrZ4s6tHJHyW8aO4DxgxLpllALcB5x/dldUOvvvVv5z8lv991ZSbmnn7/3k6X+7VNPruXf/nLxdto8MJnJy3cwYfZGmg+fxMvTknKd77M3JS3sksAmcvIc
B1Ja2DgQY7zbsjeVxyet5MW/npot6cxYnUytKuW5vIjGn/hc0a0Jj/+pE72fmcH+1ONknlAG92rGv+ZmTQC5aawz7+q2/amc/dQMvrjtTA4fy2DwhPmIwMYxweZlNYVV2HEgxpgyJqFuFd4anJjrjqVP+wZ0S6jNnee3Dnv+I5d1yNfn/Xfxdq5951f2pKT7p2DZvDc1aN1f1jsTRY6ZvJr/LHAefeX8PTjwF+Oy8ktyccgzgYhIFxGpEY1gjDGx4d4L27LggQtY5y6kBXDn+a157Zru/KHrSVzncWXGQIu3ZG9e/XntbiqVd35EdW1aE1XlmSmr/Y3w8zfuo2ZlZyxKd/eRW2p6BgdS02kx4jv+/atz99JixHfc+sGifMdj8uZlMsUvgAtwplQ3xhhEhPrVnXXdZw3rw9DPlnLzOS2pWbk8l3R2Zgv+eMgZVCofzx9fnZPr/Nev6c6t//5fnp9z7LjTtrF020HOfmoG2w8czXZ8gTuPV9KuFPakpJE4eirVKzo/1sZNWUNzdz6xySt2Zjvvwud+5uJOjbi3X/Bp8ifM3sj+1HTuC3HcOLw8wvpMVf398UTE+yIGxphS7+Q6VfjPkF7+uwGf01vWpevJtfjwb6cDMOrSDvz75tMZmNiUiztnTUl/34VtPX1OzuQBWdOpHE7LIHH0VP82wP7U40FnLj5xQklKTuGl6esAWL87Jdegxse+XcnL7nGftbsOM2edrUUfyMsdSAsR+Q+wyt0/B5gRuZCMMaXJma3qsfDBC6hXzbljOau1M/Bx/gN9OXECXp4euodVYQVLOoHzeakqfZ/9GXAmh6xdpQKdm2ZNBZN5Qv3rqvR7fiaQ1ZgfTGp6RrZ2o08WbGXY58v46vaz6BrQ06y08HIH0hiYjDOlySZsHIgxJp98ySNQg+qVaFSzEle63YQj7fxxP9F8+CRemLbWXxZ4hzJ4wnwue2U265KzVm3ccfAoGZknsq2ZArB65yHGTVmTrYH+k4Vb6TBqCiu2Z803+4HbDjPOHX3vsy45xT93WCzzkkBuyDEO5LZIB2WMKTu6J9Rm6r29WfXYRfx0/3kAvHvDaTSs4SQd3wj6wtqwx3lM9e6cTf6yTxdty1Xvkhdn+7dfnJpE6wcm0+nhKdnqXPTCLF6ZsY6jxzP9Zb5kNG1Vsr/MN2PyoYC7nuXbDnLBcz8HXYrYi4Opx9kSoodatHlJICki8m8RWS4iE3EWlTLGmCLTukE1KleIp3m9qmwaO4A+7Rrw68gL2DR2AOe0CT5o0bdO/H+GnFGksaQHjIIPlmD2pKT5tx/9eiXnu4tt+ew8dMy/vWSr88CmZ3MnkWzdl+pfYniW257yzdLf2bY/e0L4fNE2flqTTDB/fWse5z6TvRXhaHomd320mK37optYvLSBPAF8BTwLtAHG4szKa4wxEXdSrcr+dofjmSe49YNF3HF+G049uRa393HGo1xwSkO6JdQiI1N5furacG9XaL7GeoCP3VmG35y5nia1KjvLD8cLf5+4iOf/cqo/gbw1ayNtGlRn2OdZj8zqV6vIT2uSufOjxZSPF0Zd1pG+7RtwUq3K3PepM0p/yj3nUqtKeRoGLIO8aofTITawveWThVv5eunvlIsTnri8M5UrxEf078DHy4qEQ1X1mYD9Eao6JuKRFTEbiW5M6aeqtBjxHZd1PYmuTWsyetIqnhvYNdu0KSVF3/YNmLY6+11Gq/pV+ehvZ9DzyWn+snJxwronL+HVGev4aP4Wtu13OgbMGtaHk+tUAaD3MzPYvDeVOIETCuMH9aB/x0ZFEmdhR6K3FpE67hvVA1oWSVTGGFPERIQloy7kuYFdufmclmwaO4DLu2XNOHztGQlhzz+5TuVIh+iXM3kArN99hKvempetLMNtbH9myhp/8gBIPpzGfne1yIbVnTsUX7v8LROdgZMXvziLB79cXuSx+3hJIO8DS0XkALAImBCxaIwxppBqValA+fisH20iwjvXJfLGtd25
o48zo3CodpMPb3bKL+/WhLvc6VoGndEswhFnt3537oW2mg+flKvsytd/odvjP7InJY35m/blOj7625Ws2nGID+ZtyXWsqHhpA6kHnA6kq6qNojHGxJy+p2RNUe9rT/nvbWdyxWu/8OvIvsxK2sPybQc4uU4VVj9+EZXKO20IvpHqE+c53XGfuLwTD3yxIsrRh/fRr8ETxNsF7OWVH17aQBYD/VU1eJeAGGFtIMaYgjp87DgKVCwXR7sHv8927F839mTwhPkA1KlagX1H0kmoU4Utbo+oKfecS/8XZkY75GwCk2J+FbYNZEpg8hCRywsUhTHGxKjqlcpTo1J5KpaLZ/Y/+7D68Yv406knAdC2YXX6dWjIFd2aMOO+8zipZiUmXJ/Ik5d35qqeCbRrVJ2fh57H9/ecQ2Kz2u451aIa/5TfduZdqQC83IF8D9QiayqTzqGykacPdBrkxwIbcLoFj1TVXUHqXYuzGmImsF5Vx7vlbwDtA6reqap5thLZHYgxpiilpmeQtCslX1OUqCrjZ27gL4kn8+yPa/hg3hbeveE0bnh3AQD1q1dk9+G0bOf0almXuRv2Ziv7+o6z+MMr2Sep7NK0Jsu2HSSYp6/swsDTTvYcZ6BwdyBeEshknB/4PoNU9eYCRYI/AUxX1U9E5DJgoKoOylGnKfAt0E1VVUQWAFerapKIPKKqj+T3cy2BGGNKqldnrOP0FnVo3aAany3axk1nt6DFiO8AGHNFZ0b81/kdOec8XJknlFYjnXoLHriA056YytNXdvGPN5k5tA/rd6dwXrv6iBRsDHi4BOKlEX0KsF1V17lvNrdAUWQZgDM4EWAOTi+vnPoDizQru80FLgaSgOoi8gCQARwB3lDVjCDvYYwxMcE3IBLg5nOckRKByWLNzsPc3Tf3mvTxccIrV3cjPeME9atX9J9zOC2DMd+tIqFuFRLqVolY3F7uQH4GLlHV3H3LQp8zBWgY5NAo4FOgoaoeEJFywHGgfGASEJERbp173P3RAKr6oIh0B5apaoaIPA0cVtXHQ8QxBBgCkJCQ0GPz5s3BqhljjAmhsI3oswH/5C4icn1eJ6hqf1U9NcjrayAZqO5WrQHsD3IHEVjHVy/Zfe//BdSfDpwfJo43VTVRVRPr1w8+n44xxpiC8ZJALgQ2icgMEZkBPFTIz5wE9HK3z3L3EZE4EfENE50C9JCsh3a9cKaUR0SeCXivNkD2VV+MMcZEhZc2kA3A/wXsF7gB3TUSeEpE2gKtgPvd8i7ARJxeXttEZBzwvIhkAm+rqm/VmfoiMhZIBdoB9xYyHmOMMQWQZxtIaWG9sIwxJv8K1AYiIl+KyKU5ygaIyHdFHaAxxpjYE64NZLGqfisir4vIdBFpq6qTcLreGmOMKePybERX1VuBparqW6WlbDzzMsYYE1a4BKIhto0xxpiwCeSfIpIsIsnA393t3cCIKMVmjDGmBAvXjfdN4IUcZQLcHrlwjDHGxIpwCWSYqh7PWehOM2KMMaaMC/kIK1jycMtt4kJjjDGepjIxxhhjcrEEYowxpkDynAtLRCoAtwDlgV+BpFhfH90YY0zhebkDeR6oAyQAO4DHIhqRMcaYmOAlgWxS1UeBHaq6Adge4ZiMMcbEAC8JpKWIVARUROIIvtKgMcaYMsbLeiA/ABtxpjMZAvwjohEZY4yJCXkmEFX9wl2JsDWwE9gT8aiMMcaUeHk+whKRYap6QFUXAhWBdyIfljHGmJIu5B2Iuz55c6C9iJzrFsdhM/MaY4wh/COsbsCfgFNxJlEEyAS+jXRQxhhjSr6QCURVvwK+EpHTVHVBFGMyxhgTA7ysSJgteYjIkMiFY4wxJlZ4mcpkP7Af5zFWfeAAzlohxhhjyjAv40CGqOqnACJSCbgmsiEZY4yJBV7GgXwasH1MRFoU9kNFpA4wFtgAtAFGququIPWuxWnMzwTWq+p4t7wGcA9wCOgBzFXV1woblzHGGO+8PMKaQVbX3RrAkiL4
3CeBqar6iYhcBowDBuX43KbA/UA3VVURWSAi01U1ya0/RlU3urMFtyyCmIwxxuSDl0dY84A33O3DqrqvCD53APCEuz0HeD9Inf7AIlX1Ja+5wMUisg64EPjVvROphjNjsDHGmCjy8ggr2xroInKGqs7L6zwRmULwiRdHAQ2Aw+7+IaC2iJTLsVxuYB1fvQbuqznOuiQzReRm4BXg+iAxDMGZvwsgRUTW5BV3CPUoe1O42DWXDXbNZUNhrrlZqAPhRqJPCFYMdAYS8/pEVe0f5r2Tgeo4PbpqAPuDrLWejDP/lk8NYB1OIgFncSuA2cCDIWJ4kyLoMSYiC1U1z2suTeyaywa75rIhUtcc7g7kBDAxSPmgIGX5NQnoBWwFznL3caeLb6qqW4ApwJ0iIu5jrF7Ay6p6VETm4rR7rMLJjmuLICZjjDH5EC6B3KOqKb4dEamrqntFZFERfO5I4CkRaQu0wmksB+iCk7Q6q+o2ERkHPC8imcDbbgM6wM3AfSKyHugA3F4EMRljjMmHcFOZpACIyJnAx0BNd1DhX3EatAvMbYj/W5DyJTiPyHz7HwAfBKm3EieJREtZHDhp11w22DWXDRG5Zsnq5BSigsh44CFVTRaRRsBoVY3mD29jjDElkJclbZNUNRlAVXfiNGQbY4wp47yMA2knIlfgjBpvhTNyvMwQkQuAK3B6hamqPlrMIRWIiLQCRgP/A5oCe1X1sXCzAojIUJzeb7WBH1T1a7f8VJx2p4043arvD9KLrkQQkco4PfZ+UNX73el4xgHbca53rKqudeuGmvmgOfAQzi9PzYH7AtsHSxoRaQdcBRwFegOP4Hx/c12D23HlSSAFp0PKO75u+rH03Xe/q81xuqq2AW4CKlOKvtu+J0BAV1U9zS0rsu9zuO9CSKoa9gU0Bj4EVuC0RzTO65zS8gKquH/JFd39z4G+xR1XAa/lNOCPAfsrcaaBeQMY6JZdBkx0t08HvnO3ywNJQC2crtwrgEbusWeBm4r7+sJc97M4A1XHufvDgWHudmdglrvdFGeWBd9j3QVAG3f7e6Cnu30n8HhxX1eY643H6dUY5+43xpkENeg14LRpvuZu18Hp0RgfS999oBGwL+Cav8KZs69UfbeBP7vXsTCgrMi+z6G+C+Fi8jKd+w5VvVpVOwF3qOqOvM4pRXoBm1U1zd2fgzOKPuao6gJ11njxiQOO4FyPr1NE4PVd6itX1eM4XabPxek+XVmdx5k5zylRRGQQTnwbA4r916uqy4Gu7owGoWY+KA/0wfkPCCX4el2n4fwgvFNERuD8wDlA6GsI/PvYBxwDOhJb3/1UIB3njgKc2Sl+o5R9t1X1M7IProai/T6H+i6E5GVN9NdE5AwRuR1Y7HatLStCjYaPaSJyOTBFVVcTYlYAws8EUOL/TkSkA3CKqv43x6H8Xlc94GjAf8QSeb0BmuH88H9PVcfg/GC8n9DXENP/zgCqeggYCnwsIu8B23DunkrldzuHovw+5/v6vTSib1bnOdggnGx00MM5pYVvxLxPDbcsZolIH5zfQP7hFgVeY+CsAKGuPVb+Ti4HjonIcOBsoKeI3EP+r2sPUFlEJEd5SXUIWK2qvv+ns4FOhL6GWP939rVbDAUGqOr1OP9moyi93+1ARfl9zvf1e0kg9UTkHJxGmFQP9UuTuUAzEano7vtHzcciERmAc2t7N9BIRHqRNSsAZL++b33l7m9tHYCZOA2SR90GvZznlBiq+oSqPqaqY3F+iM5X1RcIuF4R6QwsdX+DnQL0CPiP1QuY7D7imIHzaAhK6PUG+BWoKyLx7n4znMc5oa4h8O+jDlDJrR9L3/0mwD7NauzegXMdpfK7nUNRfp9DfRdC8jIO5DacQX/XAycDf1DVMrOsrYhciNN4tRs4riW4J0o4ItID+BlY6BZVBV4FvgaeAjbj9LIbrtl7qtR2X5M1e0+VO91z6lCCeqrkJCJX4vSqqYBzvV/i9FrZgTPX2pOavddK
Ik6vlbWavdfKKJwfMAnAvVqye2FdDpyP851NwPm3akiQa3B73ozBaUdIAN7SrF5YMfHdd5PlSzjP7A/g3HHdA6RRir7bItIbGAxcBLyO08gPRfR9DvddCBlTXgkkIPh6qlrWZrA0xhgTgpdG9H4isgPYICI7RKRfFOIyxhhTwnlpA7kVOFVVawDdcW7vjDHGlHFeEsivvueG7hiQhXnUN8YYUwaEW1DqXHezlojciNPg0pLgqwwaY4wpY0I2orvrfizFGdUa6DJVrRfpwIwxxpRs4SZTvEtV5wQWiEh3nK68xpgQRGQW7ngMnMkI33IPNcH5pe2vxRWbMUXJyziQKsDVwC04fab3qurpUYjNmJgkIjeo6rsi0gn4VlWb+8pxphjx1nfemBIuZCO6iHQTkTeATThTQi9R1dY4A1mMMSGo6rshDlXHndhRRG4QkZ0iMlREJorIZBEZKCLviMhMd0I8RKSjiPzLrfeOiLSM1nUYk5dwvbBm4oxW7qCqg3AmKENV10QjMGNKG1V9KWD7XWA18D/3/1caUF1VbwIWAxe6Vd8G3lDVZ4CJZI0+NqbYhWsDOQlnTv3h7jNdL11+jTH5s97980DA9n6yJrXrAvRze0VWxlnsx5gSIWQCUdXDOAuyICI9gWoi8hDQQlVvjFJ8xpR1S4H/quoyd2LDy4s7IGN8vCxpi6rOB+aLSHXcpGKMCc1dSncIUFNEblTVCe7EpDVF5CqcabWbAdeLyNc4dxqDROR3nDU8OovIZJylWe8TkY04PSA/KI7rMSYYz5Mp+k8QqaCq6RGKxxhjTIzIdwIxxhhjwBrGjTHGFJAlEGOMMQViCcQYY0yBWAIxxhhTIJZAjDHGFIglEGOMMQViCcQYY0yB/D//bnnJjmMTygAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compare total energy (kinetic + potential) along the predicted states `pred_tall`\n",
    "# with the reference states `new_dataset['x']`, normalized by `scale`.\n",
    "scale = 29.4  # max potential energy used as the normalizer (see y-axis label)\n",
    "\n",
    "def batched_total_energy(states):\n",
    "    # kinetic + potential energy of each state, vectorized over the leading axis\n",
    "    return jax.vmap(kinetic_energy)(states) + jax.vmap(potential_energy)(states)\n",
    "\n",
    "total_true_energy = batched_total_energy(new_dataset['x'])\n",
    "total_predicted_energy = batched_total_energy(pred_tall)\n",
    "\n",
    "plt.plot((total_predicted_energy - total_true_energy) / scale)\n",
    "plt.ylabel('Absolute Error in Total Energy/Max Potential Energy')\n",
    "plt.xlabel('Time')\n",
    "plt.ylim(-0.06, 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "# np.save('baseline_dblpend_energy.npy', total_predicted_energy)\n",
    "# np.save('baseline_dblpend_prediction.npy', pred_tall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = jax.random.PRNGKey(1_000_000_000)  # fixed seed (1e9); this key is consumed by the cells below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(100000, 4)\n"
     ]
    }
   ],
   "source": [
    "# Build the derivative dataset once and slice inputs/targets from the same result.\n",
    "n_eval = 100000\n",
    "derivative_dataset = get_derivative_dataset(rng)\n",
    "batch_data = derivative_dataset[0][:n_eval], derivative_dataset[1][:n_eval]\n",
    "print(batch_data[0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(0.07387472, dtype=float32)"
      ]
     },
     "execution_count": 110,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loss of `best_params` on `batch_data` with l2reg=0.0, divided by the number of samples.\n",
    "# NOTE(review): `loss` and `best_params` are assigned by the search-loop cell further down.\n",
    "loss(best_params, batch_data, 0.0) / batch_data[0].shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running best test loss and its parameters, tracked across every restart of the search loop.\n",
    "best_loss, best_params = np.inf, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running 0\n",
      "Cur best 0.5416298\n",
      "iteration=1, train_loss=3.281229, test_loss=3.362365\n",
      "iteration=2, train_loss=3.338815, test_loss=3.414896\n",
      "iteration=3, train_loss=3.173423, test_loss=3.245358\n",
      "iteration=5, train_loss=2.876028, test_loss=2.953962\n",
      "iteration=8, train_loss=2.617690, test_loss=2.711931\n",
      "iteration=12, train_loss=2.461916, test_loss=2.569056\n",
      "iteration=14, train_loss=2.420607, test_loss=2.529519\n",
      "iteration=16, train_loss=2.393510, test_loss=2.502541\n",
      "iteration=24, train_loss=2.345873, test_loss=2.449337\n",
      "iteration=69, train_loss=2.283767, test_loss=2.380207\n",
      "iteration=71, train_loss=2.279208, test_loss=2.375960\n",
      "iteration=100, train_loss=2.052558, test_loss=2.182241\n",
      "iteration=126, train_loss=2.150693, test_loss=2.242196\n",
      "iteration=138, train_loss=2.043466, test_loss=2.179271\n",
      "iteration=139, train_loss=2.103001, test_loss=2.250430\n",
      "iteration=159, train_loss=1.884455, test_loss=2.035448\n",
      "iteration=200, train_loss=2.981948, test_loss=3.082102\n",
      "iteration=300, train_loss=2.960814, test_loss=3.060463\n",
      "iteration=400, train_loss=2.960811, test_loss=3.060462\n",
      "iteration=500, train_loss=2.960812, test_loss=3.060462\n",
      "iteration=600, train_loss=2.960812, test_loss=3.060462\n",
      "iteration=700, train_loss=2.960812, test_loss=3.060462\n",
      "iteration=800, train_loss=2.960811, test_loss=3.060461\n",
      "iteration=900, train_loss=2.960811, test_loss=3.060461\n",
      "Running 1\n",
      "Cur best 0.5416298\n",
      "iteration=1, train_loss=3.317755, test_loss=3.405537\n",
      "iteration=7, train_loss=3.069434, test_loss=3.143141\n",
      "iteration=9, train_loss=2.917306, test_loss=3.000417\n",
      "iteration=15, train_loss=2.608413, test_loss=2.708017\n",
      "iteration=18, train_loss=2.543802, test_loss=2.646056\n",
      "iteration=32, train_loss=2.374668, test_loss=2.479213\n",
      "iteration=39, train_loss=2.299356, test_loss=2.407058\n",
      "iteration=43, train_loss=2.278494, test_loss=2.383728\n",
      "iteration=48, train_loss=2.254509, test_loss=2.356943\n",
      "iteration=56, train_loss=2.220623, test_loss=2.325905\n",
      "iteration=79, train_loss=2.060705, test_loss=2.196907\n",
      "iteration=81, train_loss=2.041281, test_loss=2.188952\n",
      "iteration=100, train_loss=2.084676, test_loss=2.204079\n",
      "iteration=125, train_loss=2.075690, test_loss=2.199195\n",
      "iteration=200, train_loss=3.000287, test_loss=3.060724\n",
      "iteration=300, train_loss=2.658924, test_loss=2.740549\n",
      "iteration=400, train_loss=2.534858, test_loss=2.626127\n",
      "iteration=500, train_loss=2.471405, test_loss=2.568029\n",
      "iteration=600, train_loss=2.436610, test_loss=2.535066\n",
      "iteration=700, train_loss=2.397138, test_loss=2.495735\n",
      "iteration=800, train_loss=2.320425, test_loss=2.424311\n",
      "iteration=900, train_loss=2.121181, test_loss=2.235008\n",
      "iteration=919, train_loss=2.059577, test_loss=2.177489\n",
      "Running 2\n",
      "Cur best 0.5416298\n",
      "iteration=1, train_loss=3.440591, test_loss=3.480830\n",
      "iteration=2, train_loss=2.827830, test_loss=2.907219\n",
      "iteration=9, train_loss=2.746610, test_loss=2.831154\n",
      "iteration=13, train_loss=2.650030, test_loss=2.734251\n",
      "iteration=18, train_loss=2.537922, test_loss=2.626945\n",
      "iteration=40, train_loss=2.310189, test_loss=2.429695\n",
      "iteration=42, train_loss=2.259042, test_loss=2.373740\n",
      "iteration=50, train_loss=2.158052, test_loss=2.284200\n",
      "iteration=100, train_loss=2.599460, test_loss=2.704235\n",
      "iteration=134, train_loss=2.432410, test_loss=2.542955\n",
      "iteration=200, train_loss=2.345119, test_loss=2.447702\n",
      "iteration=240, train_loss=2.313213, test_loss=2.408730\n",
      "iteration=270, train_loss=2.267631, test_loss=2.355745\n",
      "iteration=300, train_loss=2.167197, test_loss=2.254341\n",
      "iteration=321, train_loss=2.142529, test_loss=2.229394\n",
      "iteration=400, train_loss=1.977052, test_loss=2.082325\n",
      "iteration=466, train_loss=1.770378, test_loss=1.895221\n",
      "iteration=500, train_loss=1.714534, test_loss=1.860020\n",
      "iteration=504, train_loss=1.720593, test_loss=1.862275\n",
      "iteration=559, train_loss=1.666239, test_loss=1.814709\n",
      "iteration=582, train_loss=1.630734, test_loss=1.769616\n",
      "iteration=600, train_loss=1.634274, test_loss=1.775872\n",
      "iteration=700, train_loss=1.544320, test_loss=1.685645\n",
      "iteration=712, train_loss=1.545025, test_loss=1.697033\n",
      "iteration=800, train_loss=1.527132, test_loss=1.672148\n",
      "iteration=808, train_loss=1.677876, test_loss=1.650595\n",
      "iteration=848, train_loss=1.447197, test_loss=1.600824\n",
      "iteration=900, train_loss=1.371689, test_loss=1.529193\n",
      "iteration=941, train_loss=1.347298, test_loss=1.505083\n",
      "iteration=1000, train_loss=1.319070, test_loss=1.476217\n",
      "iteration=1051, train_loss=1.266694, test_loss=1.438258\n",
      "iteration=1175, train_loss=1.226607, test_loss=1.381197\n",
      "iteration=1188, train_loss=1.246742, test_loss=1.393102\n",
      "iteration=1226, train_loss=1.175629, test_loss=1.330247\n",
      "iteration=1481, train_loss=1.119713, test_loss=1.259370\n",
      "iteration=1500, train_loss=1.069470, test_loss=1.215181\n",
      "iteration=1611, train_loss=1.071569, test_loss=1.210487\n",
      "iteration=1736, train_loss=1.024607, test_loss=1.155848\n",
      "iteration=1839, train_loss=1.002461, test_loss=1.130287\n",
      "iteration=2000, train_loss=0.970661, test_loss=1.104429\n",
      "iteration=2022, train_loss=0.947728, test_loss=1.085750\n",
      "iteration=2095, train_loss=0.972037, test_loss=1.094629\n",
      "iteration=2351, train_loss=0.924950, test_loss=1.063900\n",
      "iteration=2500, train_loss=0.945857, test_loss=1.077084\n",
      "iteration=2568, train_loss=0.932887, test_loss=1.063079\n",
      "iteration=2901, train_loss=0.957002, test_loss=1.103113\n",
      "iteration=2992, train_loss=0.892175, test_loss=1.031748\n",
      "iteration=3000, train_loss=0.901037, test_loss=1.032903\n",
      "iteration=3302, train_loss=0.900294, test_loss=1.039103\n",
      "iteration=3500, train_loss=0.931067, test_loss=1.058354\n",
      "iteration=3940, train_loss=0.926068, test_loss=1.055898\n",
      "iteration=4000, train_loss=0.893681, test_loss=1.028090\n",
      "iteration=4500, train_loss=0.896116, test_loss=1.031149\n",
      "iteration=4533, train_loss=0.897025, test_loss=1.026571\n",
      "iteration=5000, train_loss=0.901224, test_loss=1.030433\n",
      "iteration=5500, train_loss=0.935792, test_loss=1.066033\n",
      "iteration=5503, train_loss=0.875628, test_loss=1.012658\n",
      "iteration=6000, train_loss=0.886833, test_loss=1.021481\n",
      "iteration=6500, train_loss=0.868955, test_loss=1.005378\n",
      "Running 3\n",
      "Cur best 0.5416298\n",
      "iteration=1, train_loss=164639319464549901665380073472.000000, test_loss=3081212275934353162240.000000\n",
      "iteration=3, train_loss=5.067006, test_loss=5.148108\n",
      "iteration=10, train_loss=2.743804, test_loss=2.803879\n",
      "iteration=11, train_loss=2.773528, test_loss=2.827923\n",
      "iteration=12, train_loss=2.837285, test_loss=2.890387\n",
      "iteration=16, train_loss=3.127988, test_loss=3.174890\n",
      "iteration=100, train_loss=4.512383, test_loss=4.532230\n",
      "iteration=200, train_loss=4.513072, test_loss=4.532910\n",
      "iteration=300, train_loss=4.513071, test_loss=4.532909\n",
      "iteration=400, train_loss=4.513073, test_loss=4.532909\n",
      "iteration=500, train_loss=4.513072, test_loss=4.532909\n",
      "iteration=600, train_loss=4.513072, test_loss=4.532908\n",
      "iteration=700, train_loss=4.513072, test_loss=4.532909\n",
      "iteration=800, train_loss=4.513072, test_loss=4.532908\n",
      "iteration=900, train_loss=4.513072, test_loss=4.532908\n",
      "Running 4\n",
      "Cur best 0.5416298\n"
     ]
    }
   ],
   "source": [
    "for _i in range(1000):\n",
    "    print('Running', _i)\n",
    "    print('Cur best', str(best_loss))\n",
    "\n",
    "    init_random_params, nn_forward_fn = extended_mlp(args)\n",
    "    import HyperparameterSearch\n",
    "    HyperparameterSearch.nn_forward_fn = nn_forward_fn\n",
    "    _, init_params = init_random_params(rng+1, (-1, 4))\n",
    "    rng += 1\n",
    "    model = (nn_forward_fn, init_params)\n",
    "    opt_init, opt_update, get_params = optimizers.adam(3e-4)##lambda i: jnp.select([i<10000, i>= 10000], [args.lr, args.lr2]))\n",
    "    opt_state = opt_init(init_params)\n",
    "    from jax.tree_util import tree_flatten\n",
    "    from HyperparameterSearch import make_loss, train\n",
    "    loss = make_loss(args)\n",
    "    from copy import deepcopy as copy\n",
    "    train(args, model, data, rng);\n",
    "    from jax.tree_util import tree_flatten\n",
    "\n",
    "    @jax.jit\n",
    "    def update_derivative(i, opt_state, batch, l2reg):\n",
    "        params = get_params(opt_state)\n",
    "        param_update = jax.grad(loss, 0)(params, batch, l2reg)\n",
    "        leaves, _ = tree_flatten(param_update)\n",
    "        infinities = sum((~jnp.isfinite(param)).sum() for param in leaves)\n",
    "\n",
    "        def true_fun(x):\n",
    "            #No introducing NaNs.\n",
    "            return opt_update(i, param_update, opt_state), params\n",
    "\n",
    "        def false_fun(x):\n",
    "            #No introducing NaNs.\n",
    "            return opt_state, params\n",
    "\n",
    "        return jax.lax.cond(infinities==0, 0, true_fun, 0, false_fun)\n",
    "\n",
    "\n",
    "    best_small_loss = np.inf\n",
    "    (nn_forward_fn, init_params) = model\n",
    "    data = {k: jax.device_put(v) for k,v in data.items()}\n",
    "    iteration = 0\n",
    "    train_losses, test_losses = [], []\n",
    "    lr = args.lr\n",
    "    opt_init, opt_update, get_params = optimizers.adam(lr)\n",
    "    opt_state = opt_init(init_params)\n",
    "    bad_iterations = 0\n",
    "    offset = 0\n",
    "    \n",
    "    while iteration < 20000:\n",
    "        iteration += 1\n",
    "        rand_idx = jax.random.randint(rng, (args.batch_size,), 0, len(data['x']))\n",
    "        rng += 1\n",
    "\n",
    "        batch = (data['x'][rand_idx], data['dx'][rand_idx])\n",
    "        opt_state, params = update_derivative(iteration+offset, opt_state, batch, args.l2reg)\n",
    "        small_loss = loss(params, batch, 0.0)\n",
    "\n",
    "        new_small_loss = False\n",
    "        if small_loss < best_small_loss:\n",
    "\n",
    "            best_small_loss = small_loss\n",
    "            new_small_loss = True\n",
    "        \n",
    "        if jnp.isnan(small_loss).sum() or new_small_loss or (iteration % 500 == 0) or (iteration < 1000 and iteration % 100 == 0):\n",
    "            params = get_params(opt_state)\n",
    "            train_loss = loss(params, (data['x'], data['dx']), 0.0)/len(data['x'])\n",
    "            train_losses.append(train_loss)\n",
    "            test_loss = loss(params, (data['test_x'], data['test_dx']), 0.0)/len(data['test_x'])\n",
    "            test_losses.append(test_loss)\n",
    "            \n",
    "            if iteration >= 1000 and test_loss > 1.5:\n",
    "                #Only good seeds allowed!\n",
    "                break\n",
    "\n",
    "            if test_loss < best_loss:\n",
    "                best_loss = test_loss\n",
    "                best_params = copy(params)\n",
    "                bad_iterations = 0\n",
    "                offset += iteration\n",
    "                iteration = 0 #Keep going since this one is so good!\n",
    "\n",
    "            if jnp.isnan(test_loss).sum():\n",
    "                break\n",
    "                lr = lr/2\n",
    "                opt_init, opt_update, get_params = optimizers.adam(lr)\n",
    "                opt_state = opt_init(best_params)\n",
    "                bad_iterations = 0\n",
    "\n",
    "            print(f\"iteration={iteration}, train_loss={train_loss:.6f}, test_loss={test_loss:.6f}\")\n",
    "\n",
    "        bad_iterations += 1\n",
    "    \n",
    "    import pickle as pkl\n",
    "    if best_loss < np.inf:\n",
    "        pkl.dump({'params': best_params, 'args': args},\n",
    "             open('params_for_loss_{}_nupdates=1.pkl'.format(best_loss), 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "\n",
    "import lnn\n",
    "\n",
    "# Pick up edits to lnn.py. reload() returns the module object (which may differ\n",
    "# from the old one), so rebind the name, and re-run the `from` import so\n",
    "# `lagrangian_eom_rk4` does not keep pointing at the stale function.\n",
    "lnn = importlib.reload(lnn)\n",
    "from lnn import lagrangian_eom_rk4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 236,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(1.8673568, dtype=float32)"
      ]
     },
     "execution_count": 236,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-sample test loss of best_params (same normalisation as the search loop).\n",
    "loss(best_params, (data['test_x'], data['test_dx']), 0.0)/len(data['test_x'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best per-sample test loss recorded by the search loop.\n",
    "best_loss"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "main2",
   "language": "python",
   "name": "main2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
